diff --git a/AGENTS.md b/AGENTS.md index 1d4593b..e1445db 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -57,6 +57,7 @@ A change is done only when: - โœ… CI passes on all target OSes ## 4. Commit / PR hygiene +- **NEVER commit directly without presenting the full diff to the user for review first.** Always show `git diff` output and wait for explicit approval before running `git commit`. No exceptions. - Use clear commit messages (imperative, scoped). - Avoid mixing unrelated refactors with feature work. - If you changed any persistent format (db header, page layout, WAL frame format, postings format): diff --git a/README.md b/README.md index 4731e6d..461fd84 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ DecentDB is a embedded relational database engine focused on **durable writes**, - ๐Ÿ”Ž **Trigram Index** - Fast text search for `LIKE '%pattern%'` queries - ๐Ÿงช **Comprehensive Testing** - Unit tests, property tests, crash injection, and differential testing - ๐Ÿ”„ **Foreign Key Constraints** - Automatic indexing and referential integrity enforcement -- ๐Ÿ“Š **Rich Query Support** - Aggregates, subqueries (FROM, EXISTS, scalar), UPSERT, set operations, and scalar functions (string, math, UUID) +- ๐Ÿ“Š **Rich Query Support** - Aggregates, subqueries (FROM, EXISTS, scalar), UPSERT, set operations, and scalar functions (string, math, UUID, JSON) - โšก **Triggers** - AFTER and INSTEAD OF triggers for complex logic - ๐Ÿ“ฆ **Single File Database** - Portable database stored in a single file - ๐ŸŒ **Cross-Platform** - Runs on Linux, macOS, and Windows diff --git a/benchmarks/embedded_compare/assets/decentdb-benchmarks.png b/benchmarks/embedded_compare/assets/decentdb-benchmarks.png index 74a2e5b..ff82bf7 100644 Binary files a/benchmarks/embedded_compare/assets/decentdb-benchmarks.png and b/benchmarks/embedded_compare/assets/decentdb-benchmarks.png differ diff --git a/benchmarks/embedded_compare/assets/decentdb-benchmarks.svg 
b/benchmarks/embedded_compare/assets/decentdb-benchmarks.svg index 843dfbd..37e35ee 100644 --- a/benchmarks/embedded_compare/assets/decentdb-benchmarks.svg +++ b/benchmarks/embedded_compare/assets/decentdb-benchmarks.svg @@ -6,7 +6,7 @@ - 2026-02-21T14:23:58.539289 + 2026-02-22T15:32:19.001293 image/svg+xml @@ -30,188 +30,188 @@ z - - +" clip-path="url(#p290a25212b)" style="fill: #1f77b4"/> - +" clip-path="url(#p290a25212b)" style="fill: #1f77b4"/> - +" clip-path="url(#p290a25212b)" style="fill: #1f77b4"/> - +" clip-path="url(#p290a25212b)" style="fill: #1f77b4"/> - +" clip-path="url(#p290a25212b)" style="fill: #ff7f0e"/> - +" clip-path="url(#p290a25212b)" style="fill: #ff7f0e"/> - +" clip-path="url(#p290a25212b)" style="fill: #ff7f0e"/> - +" clip-path="url(#p290a25212b)" style="fill: #ff7f0e"/> - +" clip-path="url(#p290a25212b)" style="fill: #2ca02c"/> - +" clip-path="url(#p290a25212b)" style="fill: #2ca02c"/> - +" clip-path="url(#p290a25212b)" style="fill: #2ca02c"/> - +" clip-path="url(#p290a25212b)" style="fill: #2ca02c"/> - +" clip-path="url(#p290a25212b)" style="fill: #d62728"/> - +" clip-path="url(#p290a25212b)" style="fill: #d62728"/> - +" clip-path="url(#p290a25212b)" style="fill: #d62728"/> - +" clip-path="url(#p290a25212b)" style="fill: #d62728"/> - +" clip-path="url(#p290a25212b)" style="fill: #9467bd"/> - +" clip-path="url(#p290a25212b)" style="fill: #9467bd"/> - +" clip-path="url(#p290a25212b)" style="fill: #9467bd"/> - +" clip-path="url(#p290a25212b)" style="fill: #9467bd"/> - - + - + - + - + - + - + - + - + - - + - - + + - - + - - + + - - + - - + + + + + - - + - + - - + + - - - + - + - - + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - + - - + + - - + - + - - + + - - + - + - + - @@ -1232,18 +1306,18 @@ L 450 26.88 " style="fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square"/> - - - + - - + - - + @@ -1411,16 +1485,16 @@ z - - + @@ -1430,16 +1504,16 @@ z - - + - - + - - + @@ 
-1541,8 +1615,8 @@ z - - + + diff --git a/benchmarks/embedded_compare/data/bench_summary.json b/benchmarks/embedded_compare/data/bench_summary.json index 004a46e..41172b5 100644 --- a/benchmarks/embedded_compare/data/bench_summary.json +++ b/benchmarks/embedded_compare/data/bench_summary.json @@ -1,16 +1,16 @@ { "engines": { "DecentDB": { - "commit_p95_ms": 3.01407, - "insert_rows_per_sec": 1109262.340543539, - "join_p95_ms": 0.366649, - "read_p95_ms": 0.001212 + "commit_p95_ms": 3.031444, + "insert_rows_per_sec": 1025956.704627065, + "join_p95_ms": 0.404229, + "read_p95_ms": 0.001894 }, "DuckDB": { - "commit_p95_ms": 11.938741, - "insert_rows_per_sec": 10514.98174599169, - "join_p95_ms": 1.805541, - "read_p95_ms": 0.204874 + "commit_p95_ms": 13.042862, + "insert_rows_per_sec": 9165.197180052131, + "join_p95_ms": 1.539912, + "read_p95_ms": 0.208532 }, "H2": { "read_p95_ms": 0.01785111746 @@ -20,17 +20,17 @@ "read_p95_ms": 0.018636946209999998 }, "SQLite": { - "commit_p95_ms": 3.017577, - "insert_rows_per_sec": 1110987.668036885, - "join_p95_ms": 0.39961, - "read_p95_ms": 0.002304 + "commit_p95_ms": 3.00779, + "insert_rows_per_sec": 1010917.913465427, + "join_p95_ms": 0.519636, + "read_p95_ms": 0.002595 } }, "metadata": { "durability_profile": "safe", "machine": "batman (AMD Ryzen 9 3900X 12-Core Processor)", "notes": "Generated from raw benchmark outputs in benchmarks/embedded_compare/raw/sample; merged extra engines from benchmarks/python_embedded_compare/out/results_merged.json", - "run_id": "20260221_202357", + "run_id": "20260222_213217", "units": { "commit_p95_ms": "ms (lower is better)", "insert_rows_per_sec": "rows/sec (higher is better)", diff --git a/bindings/dotnet/src/DecentDB.AdoNet/DecentDBCommand.cs b/bindings/dotnet/src/DecentDB.AdoNet/DecentDBCommand.cs index b485f00..973fd63 100644 --- a/bindings/dotnet/src/DecentDB.AdoNet/DecentDBCommand.cs +++ b/bindings/dotnet/src/DecentDB.AdoNet/DecentDBCommand.cs @@ -139,9 +139,33 @@ public override void 
Cancel() public override int ExecuteNonQuery() { - using var reader = ExecuteDbDataReader(CommandBehavior.Default); - while (reader.Read()) { } - return reader.RecordsAffected; + var statements = SqlStatementSplitter.Split(_commandText); + if (statements.Count <= 1) + { + using var reader = ExecuteDbDataReader(CommandBehavior.Default); + while (reader.Read()) { } + return reader.RecordsAffected; + } + + // Multi-statement: execute each individually, sum affected rows + var totalRows = 0; + var savedText = _commandText; + try + { + foreach (var stmt in statements) + { + _commandText = stmt; + using var reader = ExecuteDbDataReader(CommandBehavior.Default); + while (reader.Read()) { } + if (reader.RecordsAffected > 0) + totalRows += reader.RecordsAffected; + } + } + finally + { + _commandText = savedText; + } + return totalRows; } public override object? ExecuteScalar() @@ -156,11 +180,13 @@ public override int ExecuteNonQuery() public override Task ExecuteNonQueryAsync(CancellationToken cancellationToken) { + cancellationToken.ThrowIfCancellationRequested(); return Task.FromResult(ExecuteNonQuery()); } public override Task ExecuteScalarAsync(CancellationToken cancellationToken) { + cancellationToken.ThrowIfCancellationRequested(); return Task.FromResult(ExecuteScalar()); } @@ -173,6 +199,8 @@ protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) var db = _connection.GetNativeDb(); var (sql, paramMap) = SqlParameterRewriter.Rewrite(_commandText, _parameters); + SqlParameterRewriter.ClampOffsetParameters(sql, paramMap); + sql = SqlParameterRewriter.StripUpdateDeleteAlias(sql); var observation = _connection.TryStartSqlObservation(sql, SnapshotParameters(paramMap)); @@ -243,6 +271,7 @@ private static IReadOnlyList SnapshotParameters(Dictionary ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken cancellationToken) { + cancellationToken.ThrowIfCancellationRequested(); return Task.FromResult(ExecuteDbDataReader(behavior)); } diff 
--git a/bindings/dotnet/src/DecentDB.AdoNet/DecentDBConnection.cs b/bindings/dotnet/src/DecentDB.AdoNet/DecentDBConnection.cs index d86b426..b8ebdbd 100644 --- a/bindings/dotnet/src/DecentDB.AdoNet/DecentDBConnection.cs +++ b/bindings/dotnet/src/DecentDB.AdoNet/DecentDBConnection.cs @@ -6,6 +6,8 @@ using System.IO; using System.Collections.Generic; using System.Text; +using System.Threading; +using System.Threading.Tasks; using DecentDB.Native; namespace DecentDB.AdoNet @@ -144,6 +146,13 @@ public override void Open() } } + public override Task OpenAsync(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + Open(); + return Task.CompletedTask; + } + protected override DbCommand CreateDbCommand() { return new DecentDBCommand(this); diff --git a/bindings/dotnet/src/DecentDB.AdoNet/DecentDBDataReader.cs b/bindings/dotnet/src/DecentDB.AdoNet/DecentDBDataReader.cs index f5081ad..a832925 100644 --- a/bindings/dotnet/src/DecentDB.AdoNet/DecentDBDataReader.cs +++ b/bindings/dotnet/src/DecentDB.AdoNet/DecentDBDataReader.cs @@ -185,7 +185,7 @@ public override T GetFieldValue(int ordinal) else if (nonNullableType == typeof(decimal)) { var type = _statement.ColumnType(ordinal); - if (type == 12) + if (type == 12) { boxed = _statement.GetDecimal(ordinal); } diff --git a/bindings/dotnet/src/DecentDB.AdoNet/SqlParameterRewriter.cs b/bindings/dotnet/src/DecentDB.AdoNet/SqlParameterRewriter.cs index b85a57b..e358157 100644 --- a/bindings/dotnet/src/DecentDB.AdoNet/SqlParameterRewriter.cs +++ b/bindings/dotnet/src/DecentDB.AdoNet/SqlParameterRewriter.cs @@ -205,7 +205,7 @@ int AllocateIndex() } } - if (char.IsLetter(sql[j]) || sql[j] == '_') + if (char.IsLetterOrDigit(sql[j]) || sql[j] == '_') { while (j < sql.Length && (char.IsLetterOrDigit(sql[j]) || sql[j] == '_')) { @@ -306,5 +306,129 @@ private static bool TryResolveIndexed(Dictionary parameters return null; } + + /// + /// Clamps OFFSET parameter values to 0 when negative. 
+ /// EF Core may generate negative OFFSET values from untrusted page numbers. + /// + public static void ClampOffsetParameters(string sql, Dictionary paramMap) + { + var searchFrom = 0; + while (true) + { + var idx = sql.IndexOf("OFFSET", searchFrom, StringComparison.OrdinalIgnoreCase); + if (idx < 0) break; + + var afterOffset = idx + 6; + while (afterOffset < sql.Length && sql[afterOffset] == ' ') + afterOffset++; + + if (afterOffset < sql.Length && sql[afterOffset] == '$') + { + var numStart = afterOffset + 1; + var numEnd = numStart; + while (numEnd < sql.Length && char.IsDigit(sql[numEnd])) + numEnd++; + + if (numEnd > numStart && + int.TryParse(sql.Substring(numStart, numEnd - numStart), out var paramIndex) && + paramMap.TryGetValue(paramIndex, out var param) && + param.Value is IConvertible conv) + { + try + { + var val = conv.ToInt64(null); + if (val < 0) + param.Value = 0L; + } + catch { /* not numeric โ€” leave as-is */ } + } + } + + searchFrom = afterOffset; + } + } + + /// + /// Strips table aliases from UPDATE and DELETE statements that DecentDB core doesn't support. + /// Transforms: UPDATE "Table" AS "t" SET ... WHERE "t"."Col" = ... + /// Into: UPDATE "Table" SET ... WHERE "Table"."Col" = ... + /// Same for DELETE FROM "Table" AS "t" ... 
+ /// + public static string StripUpdateDeleteAlias(string sql) + { + if (sql is null) return sql!; + + var trimmed = sql.TrimStart(); + bool isUpdate = trimmed.StartsWith("UPDATE ", StringComparison.OrdinalIgnoreCase); + bool isDelete = trimmed.StartsWith("DELETE ", StringComparison.OrdinalIgnoreCase); + if (!isUpdate && !isDelete) return sql; + + // Find the table name (quoted or unquoted) and the AS alias + // Pattern: UPDATE "TableName" AS "alias" or DELETE FROM "TableName" AS "alias" + int tableStart; + if (isUpdate) + { + tableStart = trimmed.IndexOf("UPDATE ", StringComparison.OrdinalIgnoreCase) + 7; + } + else + { + tableStart = trimmed.IndexOf("FROM ", StringComparison.OrdinalIgnoreCase); + if (tableStart < 0) return sql; + tableStart += 5; + } + + while (tableStart < trimmed.Length && trimmed[tableStart] == ' ') tableStart++; + + // Extract table name + string tableName; + int tableEnd; + if (tableStart < trimmed.Length && trimmed[tableStart] == '"') + { + var closeQuote = trimmed.IndexOf('"', tableStart + 1); + if (closeQuote < 0) return sql; + tableName = trimmed.Substring(tableStart, closeQuote - tableStart + 1); + tableEnd = closeQuote + 1; + } + else + { + tableEnd = tableStart; + while (tableEnd < trimmed.Length && !char.IsWhiteSpace(trimmed[tableEnd])) tableEnd++; + tableName = trimmed.Substring(tableStart, tableEnd - tableStart); + } + + // Look for AS "alias" after table name + var afterTable = tableEnd; + while (afterTable < trimmed.Length && trimmed[afterTable] == ' ') afterTable++; + + if (afterTable + 2 >= trimmed.Length) return sql; + if (!trimmed.Substring(afterTable, 2).Equals("AS", StringComparison.OrdinalIgnoreCase)) return sql; + if (afterTable + 2 < trimmed.Length && char.IsLetterOrDigit(trimmed[afterTable + 2])) return sql; + + var aliasStart = afterTable + 2; + while (aliasStart < trimmed.Length && trimmed[aliasStart] == ' ') aliasStart++; + + // Extract alias + string alias; + int aliasEnd; + if (aliasStart < trimmed.Length && 
trimmed[aliasStart] == '"') + { + var closeQuote = trimmed.IndexOf('"', aliasStart + 1); + if (closeQuote < 0) return sql; + alias = trimmed.Substring(aliasStart, closeQuote - aliasStart + 1); + aliasEnd = closeQuote + 1; + } + else + { + aliasEnd = aliasStart; + while (aliasEnd < trimmed.Length && !char.IsWhiteSpace(aliasEnd < trimmed.Length ? trimmed[aliasEnd] : ' ')) aliasEnd++; + alias = trimmed.Substring(aliasStart, aliasEnd - aliasStart); + } + + // Remove "AS alias" and replace "alias". references with "TableName". + var result = trimmed.Substring(0, tableEnd) + trimmed.Substring(aliasEnd); + result = result.Replace(alias + ".", tableName + "."); + return result; + } } } diff --git a/bindings/dotnet/src/DecentDB.AdoNet/SqlStatementSplitter.cs b/bindings/dotnet/src/DecentDB.AdoNet/SqlStatementSplitter.cs new file mode 100644 index 0000000..9d3e8e4 --- /dev/null +++ b/bindings/dotnet/src/DecentDB.AdoNet/SqlStatementSplitter.cs @@ -0,0 +1,91 @@ +using System; +using System.Collections.Generic; + +namespace DecentDB.AdoNet +{ + /// + /// Splits a multi-statement SQL string into individual statements. + /// Handles single-quoted strings, double-quoted identifiers, and hex literals (X'...'). + /// + public static class SqlStatementSplitter + { + public static List Split(string sql) + { + if (string.IsNullOrWhiteSpace(sql)) + return new List(); + + var statements = new List(); + var i = 0; + var stmtStart = 0; + + while (i < sql.Length) + { + var ch = sql[i]; + + if (ch == '\'') + { + i = SkipQuotedString(sql, i, '\''); + } + else if (ch == '"') + { + i = SkipQuotedString(sql, i, '"'); + } + else if (ch == '-' && i + 1 < sql.Length && sql[i + 1] == '-') + { + // Line comment: skip to end of line + i = sql.IndexOf('\n', i); + if (i < 0) i = sql.Length; + else i++; + } + else if (ch == '/' && i + 1 < sql.Length && sql[i + 1] == '*') + { + // Block comment: skip to */ + var end = sql.IndexOf("*/", i + 2, StringComparison.Ordinal); + i = end < 0 ? 
sql.Length : end + 2; + } + else if (ch == ';') + { + var stmt = sql.Substring(stmtStart, i - stmtStart).Trim(); + if (stmt.Length > 0) + statements.Add(stmt); + stmtStart = i + 1; + i++; + } + else + { + i++; + } + } + + // Remaining text after last semicolon + if (stmtStart < sql.Length) + { + var stmt = sql.Substring(stmtStart).Trim(); + if (stmt.Length > 0) + statements.Add(stmt); + } + + return statements; + } + + private static int SkipQuotedString(string sql, int start, char quoteChar) + { + var i = start + 1; + while (i < sql.Length) + { + if (sql[i] == quoteChar) + { + // Check for escaped quote (doubled quote char) + if (i + 1 < sql.Length && sql[i + 1] == quoteChar) + { + i += 2; + continue; + } + return i + 1; + } + i++; + } + return sql.Length; + } + } +} diff --git a/bindings/dotnet/src/DecentDB.EntityFrameworkCore.NodaTime/DecentDBNodaTimeOptionsExtension.cs b/bindings/dotnet/src/DecentDB.EntityFrameworkCore.NodaTime/DecentDBNodaTimeOptionsExtension.cs index c2b91ae..427ce1b 100644 --- a/bindings/dotnet/src/DecentDB.EntityFrameworkCore.NodaTime/DecentDBNodaTimeOptionsExtension.cs +++ b/bindings/dotnet/src/DecentDB.EntityFrameworkCore.NodaTime/DecentDBNodaTimeOptionsExtension.cs @@ -1,4 +1,6 @@ +using DecentDB.EntityFrameworkCore.NodaTime.Query.Internal; using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; @@ -13,7 +15,10 @@ public DbContextOptionsExtensionInfo Info => _info ??= new ExtensionInfo(this); public void ApplyServices(IServiceCollection services) - => services.Replace(ServiceDescriptor.Singleton()); + { + services.Replace(ServiceDescriptor.Singleton()); + services.AddScoped(); + } public void Validate(IDbContextOptions options) { @@ -33,7 +38,7 @@ public override string LogFragment => "using NodaTime "; public override int 
GetServiceProviderHashCode() - => 1; + => typeof(DecentDBNodaTimeMemberTranslatorPlugin).GetHashCode(); public override void PopulateDebugInfo(IDictionary debugInfo) => debugInfo["DecentDB:NodaTime"] = "1"; diff --git a/bindings/dotnet/src/DecentDB.EntityFrameworkCore.NodaTime/Query/Internal/DecentDBNodaTimeMemberTranslator.cs b/bindings/dotnet/src/DecentDB.EntityFrameworkCore.NodaTime/Query/Internal/DecentDBNodaTimeMemberTranslator.cs new file mode 100644 index 0000000..af9e434 --- /dev/null +++ b/bindings/dotnet/src/DecentDB.EntityFrameworkCore.NodaTime/Query/Internal/DecentDBNodaTimeMemberTranslator.cs @@ -0,0 +1,134 @@ +using System.Reflection; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; +using NodaTime; + +namespace DecentDB.EntityFrameworkCore.NodaTime.Query.Internal; + +/// +/// Translates NodaTime member accesses (e.g. LocalDate.Year) into SQL expressions. +/// LocalDate is stored as epoch days (days since 1970-01-01). +/// Uses the Hinnant civil calendar algorithm for correct date component extraction. +/// +public sealed class DecentDBNodaTimeMemberTranslator : IMemberTranslator +{ + private static readonly MemberInfo LocalDateYear = typeof(LocalDate).GetProperty(nameof(LocalDate.Year))!; + private static readonly MemberInfo LocalDateMonth = typeof(LocalDate).GetProperty(nameof(LocalDate.Month))!; + private static readonly MemberInfo LocalDateDay = typeof(LocalDate).GetProperty(nameof(LocalDate.Day))!; + private static readonly MemberInfo LocalDateDayOfYear = typeof(LocalDate).GetProperty(nameof(LocalDate.DayOfYear))!; + + private readonly ISqlExpressionFactory _sql; + + public DecentDBNodaTimeMemberTranslator(ISqlExpressionFactory sqlExpressionFactory) + { + _sql = sqlExpressionFactory; + } + + public SqlExpression? Translate( + SqlExpression? 
instance, + MemberInfo member, + Type returnType, + IDiagnosticsLogger logger) + { + if (instance is null) + return null; + + if (member.DeclaringType != typeof(LocalDate)) + return null; + + if (member.Equals(LocalDateYear)) + return ExtractYear(instance); + if (member.Equals(LocalDateMonth)) + return ExtractMonth(instance); + if (member.Equals(LocalDateDay)) + return ExtractDay(instance); + if (member.Equals(LocalDateDayOfYear)) + return ExtractDayOfYear(instance); + + return null; + } + + // Hinnant civil_from_days algorithm โ€” pure integer arithmetic + // Input: d = epoch days (days since 1970-01-01) + // shifted = d + 719468 (days since 0000-03-01) + // era = shifted / 146097 (400-year era) + // doe = shifted - era * 146097 (day of era) + // yoe = (doe - doe/1460 + doe/36524 - doe/146096) / 365 (year of era) + // doy = doe - (365*yoe + yoe/4 - yoe/100) (day of year, March-based) + // mp = (5*doy + 2) / 153 (month period) + // m = mp < 10 ? mp + 3 : mp - 9 (month 1-12) + // year = yoe + era*400 + (m <= 2 ? 
1 : 0) + + private SqlExpression ExtractYear(SqlExpression epochDays) + { + var d = EpochDaysAsLong(epochDays); + var (yoe, era, _, mp) = ComputeCivilComponents(d); + var m = MonthFromMp(mp); + var monthAdj = _sql.Case( + [new CaseWhenClause(_sql.LessThanOrEqual(m, Const(2)), Const(1))], + Const(0)); + return Add(Add(yoe, Mul(era, Const(400))), monthAdj); + } + + private SqlExpression ExtractMonth(SqlExpression epochDays) + { + var d = EpochDaysAsLong(epochDays); + var (_, _, _, mp) = ComputeCivilComponents(d); + return MonthFromMp(mp); + } + + private SqlExpression ExtractDay(SqlExpression epochDays) + { + var d = EpochDaysAsLong(epochDays); + var (_, _, doy, mp) = ComputeCivilComponents(d); + return Add(Sub(doy, Div(Add(Mul(Const(153), mp), Const(2)), Const(5))), Const(1)); + } + + private SqlExpression ExtractDayOfYear(SqlExpression epochDays) + { + var d = EpochDaysAsLong(epochDays); + var (_, _, doy, _) = ComputeCivilComponents(d); + return Add(doy, Const(1)); + } + + /// + /// Convert the NodaTime-typed column to long for integer arithmetic. + /// Without this, SqlBinaryExpression inherits LocalDate as the CLR type, + /// causing GroupBy translation to fail with "No coercion operator". 
+ /// + private SqlExpression EpochDaysAsLong(SqlExpression instance) + => _sql.Convert(instance, typeof(long)); + + private (SqlExpression yoe, SqlExpression era, SqlExpression doy, SqlExpression mp) ComputeCivilComponents(SqlExpression epochDays) + { + var shifted = Add(epochDays, Const(719468)); + var era = Div(shifted, Const(146097)); + var doe = Sub(shifted, Mul(era, Const(146097))); + var yoe = Div(Sub(Sub(Add(doe, Div(doe, Const(36524))), Div(doe, Const(1460))), Div(doe, Const(146096))), Const(365)); + var doy = Sub(doe, Add(Sub(Mul(Const(365), yoe), Div(yoe, Const(100))), Div(yoe, Const(4)))); + var mp = Div(Add(Mul(Const(5), doy), Const(2)), Const(153)); + return (yoe, era, doy, mp); + } + + private SqlExpression MonthFromMp(SqlExpression mp) + => _sql.Case( + [new CaseWhenClause(_sql.LessThan(mp, Const(10)), Add(mp, Const(3)))], + Sub(mp, Const(9))); + + private SqlExpression Const(int value) + => _sql.Constant(value, typeof(int)); + + private SqlExpression Add(SqlExpression left, SqlExpression right) + => _sql.Add(left, right); + + private SqlExpression Sub(SqlExpression left, SqlExpression right) + => _sql.Subtract(left, right); + + private SqlExpression Mul(SqlExpression left, SqlExpression right) + => _sql.Multiply(left, right); + + private SqlExpression Div(SqlExpression left, SqlExpression right) + => _sql.Divide(left, right); +} diff --git a/bindings/dotnet/src/DecentDB.EntityFrameworkCore.NodaTime/Query/Internal/DecentDBNodaTimeMemberTranslatorPlugin.cs b/bindings/dotnet/src/DecentDB.EntityFrameworkCore.NodaTime/Query/Internal/DecentDBNodaTimeMemberTranslatorPlugin.cs new file mode 100644 index 0000000..1a671d8 --- /dev/null +++ b/bindings/dotnet/src/DecentDB.EntityFrameworkCore.NodaTime/Query/Internal/DecentDBNodaTimeMemberTranslatorPlugin.cs @@ -0,0 +1,13 @@ +using Microsoft.EntityFrameworkCore.Query; + +namespace DecentDB.EntityFrameworkCore.NodaTime.Query.Internal; + +public sealed class DecentDBNodaTimeMemberTranslatorPlugin : 
IMemberTranslatorPlugin +{ + public DecentDBNodaTimeMemberTranslatorPlugin(ISqlExpressionFactory sqlExpressionFactory) + { + Translators = [new DecentDBNodaTimeMemberTranslator(sqlExpressionFactory)]; + } + + public IEnumerable Translators { get; } +} diff --git a/bindings/dotnet/src/DecentDB.EntityFrameworkCore.NodaTime/Storage/Internal/DecentDBNodaTimeTypeMappingSource.cs b/bindings/dotnet/src/DecentDB.EntityFrameworkCore.NodaTime/Storage/Internal/DecentDBNodaTimeTypeMappingSource.cs index aedce75..ac50a6d 100644 --- a/bindings/dotnet/src/DecentDB.EntityFrameworkCore.NodaTime/Storage/Internal/DecentDBNodaTimeTypeMappingSource.cs +++ b/bindings/dotnet/src/DecentDB.EntityFrameworkCore.NodaTime/Storage/Internal/DecentDBNodaTimeTypeMappingSource.cs @@ -31,8 +31,8 @@ public DecentDBNodaTimeTypeMappingSource( var dateTimeMapping = (RelationalTypeMapping)longMapping.WithComposedConverter( new ValueConverter( - value => new DateTimeOffset(value.Kind == DateTimeKind.Utc ? value : value.ToUniversalTime(), TimeSpan.Zero).ToUnixTimeMilliseconds(), - value => DateTimeOffset.FromUnixTimeMilliseconds(value).UtcDateTime), + value => new DateTimeOffset(value.Kind == DateTimeKind.Utc ? 
value : value.ToUniversalTime(), TimeSpan.Zero).UtcTicks, + value => new DateTime(value, DateTimeKind.Utc)), comparer: null, keyComparer: null, elementMapping: null, @@ -40,8 +40,8 @@ public DecentDBNodaTimeTypeMappingSource( var dateTimeOffsetMapping = (RelationalTypeMapping)longMapping.WithComposedConverter( new ValueConverter( - value => value.ToUniversalTime().ToUnixTimeMilliseconds(), - value => DateTimeOffset.FromUnixTimeMilliseconds(value)), + value => value.UtcTicks, + value => new DateTimeOffset(value, TimeSpan.Zero)), comparer: null, keyComparer: null, elementMapping: null, @@ -85,8 +85,8 @@ public DecentDBNodaTimeTypeMappingSource( var instantMapping = (RelationalTypeMapping)longMapping.WithComposedConverter( new ValueConverter( - value => value.ToUnixTimeMilliseconds(), - value => Instant.FromUnixTimeMilliseconds(value)), + value => value.ToUnixTimeTicks(), + value => Instant.FromUnixTimeTicks(value)), comparer: null, keyComparer: null, elementMapping: null, @@ -103,8 +103,8 @@ public DecentDBNodaTimeTypeMappingSource( var localDateTimeMapping = (RelationalTypeMapping)longMapping.WithComposedConverter( new ValueConverter( - value => value.InZoneLeniently(DateTimeZone.Utc).ToInstant().ToUnixTimeMilliseconds(), - value => Instant.FromUnixTimeMilliseconds(value).InUtc().LocalDateTime), + value => value.InZoneLeniently(DateTimeZone.Utc).ToInstant().ToUnixTimeTicks(), + value => Instant.FromUnixTimeTicks(value).InUtc().LocalDateTime), comparer: null, keyComparer: null, elementMapping: null, diff --git a/bindings/dotnet/src/DecentDB.EntityFrameworkCore/Extensions/DecentDBServiceCollectionExtensions.cs b/bindings/dotnet/src/DecentDB.EntityFrameworkCore/Extensions/DecentDBServiceCollectionExtensions.cs index bf25922..0429d60 100644 --- a/bindings/dotnet/src/DecentDB.EntityFrameworkCore/Extensions/DecentDBServiceCollectionExtensions.cs +++ b/bindings/dotnet/src/DecentDB.EntityFrameworkCore/Extensions/DecentDBServiceCollectionExtensions.cs @@ -36,6 +36,8 @@ 
public static IServiceCollection AddEntityFrameworkDecentDB(this IServiceCollect builder.TryAdd(); builder.TryAdd(); builder.TryAdd(); + builder.TryAdd(); builder.TryAddCoreServices(); return serviceCollection; diff --git a/bindings/dotnet/src/DecentDB.EntityFrameworkCore/Query/Internal/DecentDBQueryableMethodTranslatingExpressionVisitor.cs b/bindings/dotnet/src/DecentDB.EntityFrameworkCore/Query/Internal/DecentDBQueryableMethodTranslatingExpressionVisitor.cs new file mode 100644 index 0000000..7cbfafe --- /dev/null +++ b/bindings/dotnet/src/DecentDB.EntityFrameworkCore/Query/Internal/DecentDBQueryableMethodTranslatingExpressionVisitor.cs @@ -0,0 +1,295 @@ +using DecentDB.EntityFrameworkCore.Query.Internal.SqlExpressions; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; +using Microsoft.EntityFrameworkCore.Storage; + +namespace DecentDB.EntityFrameworkCore.Query.Internal; + +// Cached nullability arrays (Statics.TrueArrays is internal to EF Core) +file static class NullabilityArrays +{ + internal static readonly bool[] SingleTrue = [true]; +} + +/// +/// Translates primitive collections (e.g. string[] stored as JSON) to SQL +/// using json_each/json_array_length functions. 
+/// +public class DecentDBQueryableMethodTranslatingExpressionVisitor + : RelationalQueryableMethodTranslatingExpressionVisitor +{ + private readonly IRelationalTypeMappingSource _typeMappingSource; + private readonly ISqlExpressionFactory _sqlExpressionFactory; + private readonly SqlAliasManager _sqlAliasManager; + + public DecentDBQueryableMethodTranslatingExpressionVisitor( + QueryableMethodTranslatingExpressionVisitorDependencies dependencies, + RelationalQueryableMethodTranslatingExpressionVisitorDependencies relationalDependencies, + RelationalQueryCompilationContext queryCompilationContext) + : base(dependencies, relationalDependencies, queryCompilationContext) + { + _typeMappingSource = relationalDependencies.TypeMappingSource; + _sqlExpressionFactory = relationalDependencies.SqlExpressionFactory; + _sqlAliasManager = queryCompilationContext.SqlAliasManager; + } + + protected DecentDBQueryableMethodTranslatingExpressionVisitor( + DecentDBQueryableMethodTranslatingExpressionVisitor parentVisitor) + : base(parentVisitor) + { + _typeMappingSource = parentVisitor._typeMappingSource; + _sqlExpressionFactory = parentVisitor._sqlExpressionFactory; + _sqlAliasManager = parentVisitor._sqlAliasManager; + } + + protected override QueryableMethodTranslatingExpressionVisitor CreateSubqueryVisitor() + => new DecentDBQueryableMethodTranslatingExpressionVisitor(this); + + /// + /// Optimizes .Any() on a primitive collection to json_array_length() > 0 + /// instead of EXISTS (SELECT 1 FROM json_each(...)). + /// + protected override ShapedQueryExpression? TranslateAny( + ShapedQueryExpression source, + System.Linq.Expressions.LambdaExpression? 
predicate) + { + if (predicate is null + && source.QueryExpression is SelectExpression + { + Tables: [JsonEachExpression jsonEach], + Predicate: null, + GroupBy: [], + Having: null, + IsDistinct: false, + Limit: null, + Offset: null + }) + { + var translation = + _sqlExpressionFactory.GreaterThan( + _sqlExpressionFactory.Function( + "json_array_length", + [jsonEach.Json], + nullable: true, + argumentsPropagateNullability: NullabilityArrays.SingleTrue, + typeof(int)), + _sqlExpressionFactory.Constant(0)); + +#pragma warning disable EF1001 + return source.UpdateQueryExpression(new SelectExpression(translation, _sqlAliasManager)); +#pragma warning restore EF1001 + } + + return base.TranslateAny(source, predicate); + } + + /// + /// Optimizes .Count() on a primitive collection to json_array_length() + /// instead of SELECT COUNT(*) FROM json_each(...). + /// + protected override ShapedQueryExpression? TranslateCount( + ShapedQueryExpression source, + System.Linq.Expressions.LambdaExpression? predicate) + { + if (predicate is null + && source.QueryExpression is SelectExpression + { + Tables: [JsonEachExpression jsonEach], + Predicate: null, + GroupBy: [], + Having: null, + IsDistinct: false, + Limit: null, + Offset: null + }) + { + var translation = _sqlExpressionFactory.Function( + "json_array_length", + [jsonEach.Json], + nullable: true, + argumentsPropagateNullability: NullabilityArrays.SingleTrue, + typeof(int)); + +#pragma warning disable EF1001 + return source.UpdateQueryExpression(new SelectExpression(translation, _sqlAliasManager)); +#pragma warning restore EF1001 + } + + return base.TranslateCount(source, predicate); + } + + /// + /// Optimizes array.Contains(value) on a string[] primitive collection to + /// column LIKE '%"' || value || '"%' instead of EXISTS (SELECT FROM json_each(...)). + /// + protected override ShapedQueryExpression? 
TranslateContains( + ShapedQueryExpression source, + System.Linq.Expressions.Expression item) + { + if (source.QueryExpression is SelectExpression + { + Tables: [JsonEachExpression jsonEach], + Predicate: null, + GroupBy: [], + Having: null, + IsDistinct: false, + Limit: null, + Offset: null + } + && TranslateExpression(item) is SqlExpression translatedItem + && translatedItem.Type == typeof(string)) + { + var stringMapping = (RelationalTypeMapping)_typeMappingSource.FindMapping(typeof(string))!; + + // Build: '%"' || @value || '"%' + var pattern = _sqlExpressionFactory.Add( + _sqlExpressionFactory.Add( + _sqlExpressionFactory.Constant("%\"", stringMapping), + translatedItem, + stringMapping), + _sqlExpressionFactory.Constant("\"%", stringMapping), + stringMapping); + + var translation = _sqlExpressionFactory.Like(jsonEach.Json, pattern); + +#pragma warning disable EF1001 + return source.UpdateQueryExpression(new SelectExpression(translation, _sqlAliasManager)); +#pragma warning restore EF1001 + } + + return base.TranslateContains(source, item); + } + + /// + /// Optimizes array[index] on a primitive collection to json_extract(column, '$[index]') + /// instead of a json_each() subquery with LIMIT/OFFSET. + /// + protected override ShapedQueryExpression? 
TranslateElementAtOrDefault( + ShapedQueryExpression source, + System.Linq.Expressions.Expression index, + bool returnDefault) + { + if (!returnDefault + && source.QueryExpression is SelectExpression + { + Tables: [JsonEachExpression jsonEach], + Predicate: null, + GroupBy: [], + Having: null, + IsDistinct: false, + Limit: null, + Offset: null + } selectExpression + && TranslateExpression(index) is SqlConstantExpression { Value: int indexValue }) + { + var shaperExpression = source.ShaperExpression; + if (shaperExpression is System.Linq.Expressions.UnaryExpression + { + NodeType: System.Linq.Expressions.ExpressionType.Convert + } unaryExpression + && unaryExpression.Operand.Type.IsValueType + && Nullable.GetUnderlyingType(unaryExpression.Operand.Type) is not null) + { + shaperExpression = unaryExpression.Operand; + } + + if (shaperExpression is ProjectionBindingExpression projectionBindingExpression + && selectExpression.GetProjection(projectionBindingExpression) is ColumnExpression projectionColumn) + { + var translation = _sqlExpressionFactory.Function( + "json_extract", + [jsonEach.Json, _sqlExpressionFactory.Constant($"$[{indexValue}]")], + nullable: true, + argumentsPropagateNullability: [true, false], + projectionColumn.Type, + projectionColumn.TypeMapping); + +#pragma warning disable EF1001 + return source.UpdateQueryExpression(new SelectExpression(translation, _sqlAliasManager)); +#pragma warning restore EF1001 + } + } + + return base.TranslateElementAtOrDefault(source, index, returnDefault); + } + + /// + /// Translates a primitive collection (e.g. string[] column stored as JSON) + /// into a queryable json_each() table expression. + /// + protected override ShapedQueryExpression? TranslatePrimitiveCollection( + SqlExpression sqlExpression, + Microsoft.EntityFrameworkCore.Metadata.IProperty? 
property, + string tableAlias) + { + var elementTypeMapping = (RelationalTypeMapping?)sqlExpression.TypeMapping?.ElementTypeMapping; + var elementClrType = GetElementClrType(sqlExpression.Type); + var jsonEachExpression = new JsonEachExpression(tableAlias, sqlExpression); + + var isElementNullable = property?.GetElementType()!.IsNullable; + var keyColumnTypeMapping = _typeMappingSource.FindMapping(typeof(int))!; + var unwrappedElementType = Nullable.GetUnderlyingType(elementClrType) ?? elementClrType; + var isNullable = isElementNullable ?? (Nullable.GetUnderlyingType(elementClrType) is not null); + var nullableElementType = elementClrType.IsValueType && Nullable.GetUnderlyingType(elementClrType) is null + ? typeof(Nullable<>).MakeGenericType(elementClrType) + : elementClrType; + +#pragma warning disable EF1001 + var selectExpression = new SelectExpression( + [jsonEachExpression], + new ColumnExpression( + JsonEachExpression.ValueColumnName, + tableAlias, + unwrappedElementType, + elementTypeMapping, + isNullable), + identifier: + [ + (new ColumnExpression( + JsonEachExpression.KeyColumnName, + tableAlias, + typeof(int), + keyColumnTypeMapping, + nullable: false), + keyColumnTypeMapping.Comparer) + ], + _sqlAliasManager); +#pragma warning restore EF1001 + + selectExpression.AppendOrdering( + new OrderingExpression( + selectExpression.CreateColumnExpression( + jsonEachExpression, + JsonEachExpression.KeyColumnName, + typeof(int), + typeMapping: _typeMappingSource.FindMapping(typeof(int)), + columnNullable: false), + ascending: true)); + + System.Linq.Expressions.Expression shaperExpression = + new ProjectionBindingExpression(selectExpression, new ProjectionMember(), nullableElementType); + + if (elementClrType != shaperExpression.Type) + { + shaperExpression = System.Linq.Expressions.Expression.Convert(shaperExpression, elementClrType); + } + + return new ShapedQueryExpression(selectExpression, shaperExpression); + } + + private static Type GetElementClrType(Type 
collectionType) + { + if (collectionType.IsArray) + { + return collectionType.GetElementType()!; + } + + if (collectionType.IsGenericType) + { + return collectionType.GetGenericArguments()[0]; + } + + return collectionType; + } +} diff --git a/bindings/dotnet/src/DecentDB.EntityFrameworkCore/Query/Internal/DecentDBQueryableMethodTranslatingExpressionVisitorFactory.cs b/bindings/dotnet/src/DecentDB.EntityFrameworkCore/Query/Internal/DecentDBQueryableMethodTranslatingExpressionVisitorFactory.cs new file mode 100644 index 0000000..3830f90 --- /dev/null +++ b/bindings/dotnet/src/DecentDB.EntityFrameworkCore/Query/Internal/DecentDBQueryableMethodTranslatingExpressionVisitorFactory.cs @@ -0,0 +1,24 @@ +using Microsoft.EntityFrameworkCore.Query; + +namespace DecentDB.EntityFrameworkCore.Query.Internal; + +public sealed class DecentDBQueryableMethodTranslatingExpressionVisitorFactory + : IQueryableMethodTranslatingExpressionVisitorFactory +{ + private readonly QueryableMethodTranslatingExpressionVisitorDependencies _dependencies; + private readonly RelationalQueryableMethodTranslatingExpressionVisitorDependencies _relationalDependencies; + + public DecentDBQueryableMethodTranslatingExpressionVisitorFactory( + QueryableMethodTranslatingExpressionVisitorDependencies dependencies, + RelationalQueryableMethodTranslatingExpressionVisitorDependencies relationalDependencies) + { + _dependencies = dependencies; + _relationalDependencies = relationalDependencies; + } + + public QueryableMethodTranslatingExpressionVisitor Create(QueryCompilationContext queryCompilationContext) + => new DecentDBQueryableMethodTranslatingExpressionVisitor( + _dependencies, + _relationalDependencies, + (RelationalQueryCompilationContext)queryCompilationContext); +} diff --git a/bindings/dotnet/src/DecentDB.EntityFrameworkCore/Query/Internal/SqlExpressions/JsonEachExpression.cs b/bindings/dotnet/src/DecentDB.EntityFrameworkCore/Query/Internal/SqlExpressions/JsonEachExpression.cs new file mode 100644 
index 0000000..ae66bc2 --- /dev/null +++ b/bindings/dotnet/src/DecentDB.EntityFrameworkCore/Query/Internal/SqlExpressions/JsonEachExpression.cs @@ -0,0 +1,69 @@ +using System.Linq.Expressions; +using System.Reflection; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +namespace DecentDB.EntityFrameworkCore.Query.Internal.SqlExpressions; + +/// +/// Represents a json_each() table-valued function call, used as an intermediate +/// expression for primitive collection translation. In most cases this gets optimized +/// away (e.g. .Any() โ†’ json_array_length() > 0) before SQL generation. +/// +public sealed class JsonEachExpression( + string alias, + SqlExpression json) + : TableValuedFunctionExpression(alias, "json_each", schema: null, builtIn: true, [json]) +{ + private static ConstructorInfo? _quotingConstructor; + + public const string KeyColumnName = "key"; + public const string ValueColumnName = "value"; + + public SqlExpression Json { get; } = json; + + protected override Expression VisitChildren(ExpressionVisitor visitor) + { + var visitedJson = (SqlExpression)visitor.Visit(Json); + return visitedJson == Json ? this : new JsonEachExpression(Alias, visitedJson); + } + + public JsonEachExpression Update(SqlExpression jsonExpression) + => jsonExpression == Json ? this : new JsonEachExpression(Alias, jsonExpression); + + public override TableExpressionBase Clone(string? 
alias, ExpressionVisitor cloningExpressionVisitor) + { + var newJson = (SqlExpression)cloningExpressionVisitor.Visit(Json); + var clone = new JsonEachExpression(alias!, newJson); + foreach (var annotation in GetAnnotations()) + { + clone.AddAnnotation(annotation.Name, annotation.Value); + } + return clone; + } + + public override JsonEachExpression WithAlias(string newAlias) + => new(newAlias, Json); + + public override Expression Quote() + => Expression.New( + _quotingConstructor ??= typeof(JsonEachExpression).GetConstructor([typeof(string), typeof(SqlExpression)])!, + Expression.Constant(Alias, typeof(string)), + Json.Quote()); + + protected override void Print(ExpressionPrinter expressionPrinter) + { + expressionPrinter.Append("json_each("); + expressionPrinter.Visit(Json); + expressionPrinter.Append(")"); + PrintAnnotations(expressionPrinter); + expressionPrinter.Append(" AS "); + expressionPrinter.Append(Alias); + } + + public override bool Equals(object? obj) + => ReferenceEquals(this, obj) || (obj is JsonEachExpression other && base.Equals(other)); + + public override int GetHashCode() + => base.GetHashCode(); +} diff --git a/bindings/dotnet/src/DecentDB.EntityFrameworkCore/Storage/Internal/DecentDBDatabaseCreator.cs b/bindings/dotnet/src/DecentDB.EntityFrameworkCore/Storage/Internal/DecentDBDatabaseCreator.cs index 61cfb5c..71f2589 100644 --- a/bindings/dotnet/src/DecentDB.EntityFrameworkCore/Storage/Internal/DecentDBDatabaseCreator.cs +++ b/bindings/dotnet/src/DecentDB.EntityFrameworkCore/Storage/Internal/DecentDBDatabaseCreator.cs @@ -43,10 +43,41 @@ public override Task DeleteAsync(CancellationToken cancellationToken = default) } public override bool HasTables() - => Exists(); + { + if (!Exists()) + return false; + + var dbConnection = Dependencies.Connection.DbConnection; + var wasOpen = dbConnection.State == ConnectionState.Open; + try + { + if (!wasOpen) + dbConnection.Open(); + + if (dbConnection is DecentDBConnection ddbConn) + { + var json = 
ddbConn.ListTablesJson(); + return json != "[]" && json.Length > 2; + } + + return false; + } + catch + { + return false; + } + finally + { + if (!wasOpen) + dbConnection.Close(); + } + } public override Task HasTablesAsync(CancellationToken cancellationToken = default) - => Task.FromResult(HasTables()); + { + cancellationToken.ThrowIfCancellationRequested(); + return Task.FromResult(HasTables()); + } private void OpenAndCloseConnection() { diff --git a/bindings/dotnet/src/DecentDB.Native/DecentDB.cs b/bindings/dotnet/src/DecentDB.Native/DecentDB.cs index 821aec9..6a756f3 100644 --- a/bindings/dotnet/src/DecentDB.Native/DecentDB.cs +++ b/bindings/dotnet/src/DecentDB.Native/DecentDB.cs @@ -244,7 +244,7 @@ public PreparedStatement BindGuid(int index1Based, Guid value) var bytes = stackalloc byte[16]; if (!value.TryWriteBytes(new Span(bytes, 16))) throw new InvalidOperationException("Failed to write Guid bytes"); - + var res = DecentDBNativeUnsafe.decentdb_bind_blob(Handle, index1Based, bytes, 16); if (res < 0) { @@ -258,7 +258,7 @@ public PreparedStatement BindDecimal(int index1Based, decimal value) { // DecentDB currently supports DECIMAL backed by INT64 (approx 18 digits). // C# decimal is 96-bit integer + scale. We must check if it fits in 64-bit. - + Span bits = stackalloc int[4]; decimal.GetBits(value, bits); int low = bits[0]; @@ -268,22 +268,22 @@ public PreparedStatement BindDecimal(int index1Based, decimal value) int scale = (flags >> 16) & 0xFF; bool isNegative = (flags & 0x80000000) != 0; - if (high != 0) + if (high != 0) { - throw new OverflowException("Value is too large for DecentDB DECIMAL (must fit in 64-bit unscaled integer)"); + throw new OverflowException("Value is too large for DecentDB DECIMAL (must fit in 64-bit unscaled integer)"); } - + // Combine Mid and Low ulong unscaledU = ((ulong)(uint)mid << 32) | (uint)low; - + if (unscaledU > (ulong)long.MaxValue + (ulong)(isNegative ? 
1 : 0)) { - throw new OverflowException("Value is too large for DecentDB DECIMAL (must fit in 64-bit unscaled integer)"); + throw new OverflowException("Value is too large for DecentDB DECIMAL (must fit in 64-bit unscaled integer)"); } - + long unscaled = (long)unscaledU; if (isNegative) unscaled = -unscaled; - + var res = DecentDBNativeUnsafe.decentdb_bind_decimal(Handle, index1Based, unscaled, scale); if (res < 0) { @@ -394,7 +394,7 @@ public bool GetBool(int col0Based) // We can just check != 0. return GetInt64(col0Based) != 0; } - + public Guid GetGuid(int col0Based) { unsafe @@ -402,12 +402,12 @@ public Guid GetGuid(int col0Based) // Try to get as blob first, as UUIDs are stored as blobs int len; var ptr = (byte*)DecentDBNative.decentdb_column_blob(Handle, col0Based, out len); - + if (ptr != null && len == 16) { return new Guid(new ReadOnlySpan(ptr, 16)); } - + // Fallback: check text if blob failed or length mismatch (e.g. legacy text UUIDs?) // Although ADR 0091 says text will no longer be accepted for ctUuid, // we might still have text columns that contain UUID strings. @@ -417,7 +417,7 @@ public Guid GetGuid(int col0Based) var s = new string((sbyte*)ptr, 0, len, System.Text.Encoding.UTF8); if (Guid.TryParse(s, out var g)) return g; } - + return Guid.Empty; } } @@ -439,11 +439,11 @@ public decimal GetDecimal(int col0Based) // decimal(int lo, int mid, int hi, bool isNegative, byte scale) bool isNegative = unscaled < 0; ulong u = isNegative ? 
(ulong)(-unscaled) : (ulong)unscaled; - + int lo = (int)(u & 0xFFFFFFFF); int mid = (int)(u >> 32); int hi = 0; - + return new decimal(lo, mid, hi, isNegative, (byte)scale); } diff --git a/bindings/dotnet/tests/DecentDB.EntityFrameworkCore.Tests/DesignTimeToolingTests.cs b/bindings/dotnet/tests/DecentDB.EntityFrameworkCore.Tests/DesignTimeToolingTests.cs index f53cba9..206f3c8 100644 --- a/bindings/dotnet/tests/DecentDB.EntityFrameworkCore.Tests/DesignTimeToolingTests.cs +++ b/bindings/dotnet/tests/DecentDB.EntityFrameworkCore.Tests/DesignTimeToolingTests.cs @@ -35,7 +35,7 @@ public void DotnetEf_MigrationsAdd_AndDatabaseUpdate_Work() enable - + all runtime; build; native; contentfiles; analyzers; buildtransitive @@ -148,9 +148,12 @@ private static (int ExitCode, string Output) Run(string workingDirectory, string RedirectStandardError = true }; + // Prevent MSBuild node reuse to avoid lock contention with the test host's MSBuild server. + psi.Environment["MSBUILDDISABLENODEREUSE"] = "1"; + psi.Environment["DOTNET_CLI_DO_NOT_USE_MSBUILD_SERVER"] = "1"; + using var process = Process.Start(psi) ?? throw new InvalidOperationException("Failed to start dotnet process."); - // Avoid deadlocks: read stdout/stderr concurrently while the process runs. var stdoutTask = process.StandardOutput.ReadToEndAsync(); var stderrTask = process.StandardError.ReadToEndAsync(); @@ -168,21 +171,35 @@ private static (int ExitCode, string Output) Run(string workingDirectory, string throw new TimeoutException($"dotnet {arguments} timed out after {DotnetCommandTimeout}."); } - Task.WaitAll(stdoutTask, stderrTask); - var output = stdoutTask.Result + stderrTask.Result; - return (process.ExitCode, output); + // Use a bounded wait for output to prevent hangs from orphaned child processes holding pipes. + if (!Task.WaitAll([stdoutTask, stderrTask], TimeSpan.FromSeconds(10))) + { + try + { + process.Kill(entireProcessTree: true); + } + catch + { + // Best-effort cleanup. 
+ } + } + + var stdout = stdoutTask.IsCompleted ? stdoutTask.Result : ""; + var stderr = stderrTask.IsCompleted ? stderrTask.Result : ""; + return (process.ExitCode, stdout + stderr); } private static void StageNativeLibrary(string repoRoot, string outputDirectory) { Directory.CreateDirectory(outputDirectory); + // Prefer build/ output (canonical build dir) over repo root to avoid stale binaries. if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) { CopyFirstExisting( - Path.Combine(repoRoot, "libdecentdb.so"), Path.Combine(repoRoot, "build", "libdecentdb.so"), Path.Combine(repoRoot, "build", "libc_api.so"), + Path.Combine(repoRoot, "libdecentdb.so"), destinationPath: Path.Combine(outputDirectory, "libdecentdb.so")); return; } @@ -190,9 +207,9 @@ private static void StageNativeLibrary(string repoRoot, string outputDirectory) if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { CopyFirstExisting( - Path.Combine(repoRoot, "libdecentdb.dylib"), Path.Combine(repoRoot, "build", "libdecentdb.dylib"), Path.Combine(repoRoot, "build", "libc_api.dylib"), + Path.Combine(repoRoot, "libdecentdb.dylib"), destinationPath: Path.Combine(outputDirectory, "libdecentdb.dylib")); return; } @@ -200,9 +217,9 @@ private static void StageNativeLibrary(string repoRoot, string outputDirectory) if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { CopyFirstExisting( - Path.Combine(repoRoot, "decentdb.dll"), Path.Combine(repoRoot, "build", "decentdb.dll"), Path.Combine(repoRoot, "build", "c_api.dll"), + Path.Combine(repoRoot, "decentdb.dll"), destinationPath: Path.Combine(outputDirectory, "decentdb.dll")); } } diff --git a/bindings/dotnet/tests/DecentDB.EntityFrameworkCore.Tests/NodaTimeIntegrationTests.cs b/bindings/dotnet/tests/DecentDB.EntityFrameworkCore.Tests/NodaTimeIntegrationTests.cs index 95a18eb..c0c0a28 100644 --- a/bindings/dotnet/tests/DecentDB.EntityFrameworkCore.Tests/NodaTimeIntegrationTests.cs +++ 
b/bindings/dotnet/tests/DecentDB.EntityFrameworkCore.Tests/NodaTimeIntegrationTests.cs @@ -46,6 +46,114 @@ public void UseNodaTime_RegistersTypeMappings_AndRoundTrips() Assert.Equal(row.LocalAt, loaded.LocalAt); } + [Fact] + public void UseNodaTime_PreservesTickPrecision_ForInstant() + { + EnsureSchema(); + + using var context = CreateContext(); + + // Instant with sub-millisecond precision (microseconds) + var preciseInstant = Instant.FromUnixTimeTicks(17095044690001234L); + var row = new NodaEvent + { + Name = "precision_test", + At = preciseInstant, + Day = new LocalDate(2026, 1, 2), + LocalAt = new LocalDateTime(2026, 1, 2, 3, 4, 5, 678).PlusNanoseconds(912300) + }; + + context.Events.Add(row); + context.SaveChanges(); + + var loaded = context.Events.Single(x => x.Name == "precision_test"); + Assert.Equal(preciseInstant, loaded.At); + Assert.Equal(row.LocalAt, loaded.LocalAt); + } + + [Fact] + public void UseNodaTime_TranslatesLocalDateYear_InGroupBy() + { + EnsureSchema(); + + using var context = CreateContext(); + context.Events.AddRange( + new NodaEvent { Name = "a", At = Instant.FromUtc(2024, 1, 1, 0, 0), Day = new LocalDate(2024, 3, 15), LocalAt = new LocalDateTime(2024, 3, 15, 0, 0, 0) }, + new NodaEvent { Name = "b", At = Instant.FromUtc(2024, 6, 1, 0, 0), Day = new LocalDate(2024, 6, 1), LocalAt = new LocalDateTime(2024, 6, 1, 0, 0, 0) }, + new NodaEvent { Name = "c", At = Instant.FromUtc(2025, 1, 1, 0, 0), Day = new LocalDate(2025, 1, 10), LocalAt = new LocalDateTime(2025, 1, 10, 0, 0, 0) }); + context.SaveChanges(); + + var byYear = context.Events + .GroupBy(e => e.Day.Year) + .Select(g => new { Year = g.Key, Count = g.Count() }) + .OrderBy(x => x.Year) + .ToList(); + + Assert.Equal(2, byYear.Count); + Assert.Equal(2024, byYear[0].Year); + Assert.Equal(2, byYear[0].Count); + Assert.Equal(2025, byYear[1].Year); + Assert.Equal(1, byYear[1].Count); + } + + [Fact] + public void UseNodaTime_TranslatesLocalDateMonthAndDay_InProjection() + { + 
EnsureSchema(); + + using var context = CreateContext(); + context.Events.Add(new NodaEvent + { + Name = "date_parts", + At = Instant.FromUtc(2026, 7, 23, 0, 0), + Day = new LocalDate(2026, 7, 23), + LocalAt = new LocalDateTime(2026, 7, 23, 0, 0, 0) + }); + context.SaveChanges(); + + var result = context.Events + .Where(e => e.Name == "date_parts") + .Select(e => new { e.Day.Year, e.Day.Month, e.Day.Day }) + .Single(); + + Assert.Equal(2026, result.Year); + Assert.Equal(7, result.Month); + Assert.Equal(23, result.Day); + } + + [Fact] + public void UseNodaTime_TranslatesLocalDateYear_ForHistoricAndEpochDates() + { + EnsureSchema(); + + using var context = CreateContext(); + context.Events.AddRange( + new NodaEvent { Name = "epoch", At = Instant.FromUtc(1970, 1, 1, 0, 0), Day = new LocalDate(1970, 1, 1), LocalAt = new LocalDateTime(1970, 1, 1, 0, 0, 0) }, + new NodaEvent { Name = "pre_epoch", At = Instant.FromUtc(1969, 12, 31, 0, 0), Day = new LocalDate(1969, 12, 31), LocalAt = new LocalDateTime(1969, 12, 31, 0, 0, 0) }, + new NodaEvent { Name = "leap", At = Instant.FromUtc(2000, 2, 29, 0, 0), Day = new LocalDate(2000, 2, 29), LocalAt = new LocalDateTime(2000, 2, 29, 0, 0, 0) }); + context.SaveChanges(); + + var results = context.Events + .OrderBy(e => e.Name) + .Select(e => new { e.Name, e.Day.Year, e.Day.Month, e.Day.Day }) + .ToList(); + + var epoch = results.Single(r => r.Name == "epoch"); + Assert.Equal(1970, epoch.Year); + Assert.Equal(1, epoch.Month); + Assert.Equal(1, epoch.Day); + + var preEpoch = results.Single(r => r.Name == "pre_epoch"); + Assert.Equal(1969, preEpoch.Year); + Assert.Equal(12, preEpoch.Month); + Assert.Equal(31, preEpoch.Day); + + var leap = results.Single(r => r.Name == "leap"); + Assert.Equal(2000, leap.Year); + Assert.Equal(2, leap.Month); + Assert.Equal(29, leap.Day); + } + private NodaDbContext CreateContext() { var optionsBuilder = new DbContextOptionsBuilder(); @@ -59,7 +167,7 @@ private void EnsureSchema() connection.Open(); using 
var command = connection.CreateCommand(); - command.CommandText = "CREATE TABLE ef_noda_events (id INTEGER PRIMARY KEY, name TEXT NOT NULL, at_ms INTEGER NOT NULL, day_num INTEGER NOT NULL, local_ms INTEGER NOT NULL)"; + command.CommandText = "CREATE TABLE ef_noda_events (id INTEGER PRIMARY KEY, name TEXT NOT NULL, at_ticks INTEGER NOT NULL, day_num INTEGER NOT NULL, local_ticks INTEGER NOT NULL)"; command.ExecuteNonQuery(); } @@ -91,9 +199,9 @@ protected override void OnModelCreating(ModelBuilder modelBuilder) entity.HasKey(x => x.Id); entity.Property(x => x.Id).HasColumnName("id").ValueGeneratedOnAdd(); entity.Property(x => x.Name).HasColumnName("name"); - entity.Property(x => x.At).HasColumnName("at_ms"); + entity.Property(x => x.At).HasColumnName("at_ticks"); entity.Property(x => x.Day).HasColumnName("day_num"); - entity.Property(x => x.LocalAt).HasColumnName("local_ms"); + entity.Property(x => x.LocalAt).HasColumnName("local_ticks"); }); } } diff --git a/bindings/dotnet/tests/DecentDB.EntityFrameworkCore.Tests/PrimitiveCollectionTests.cs b/bindings/dotnet/tests/DecentDB.EntityFrameworkCore.Tests/PrimitiveCollectionTests.cs new file mode 100644 index 0000000..14f8c7e --- /dev/null +++ b/bindings/dotnet/tests/DecentDB.EntityFrameworkCore.Tests/PrimitiveCollectionTests.cs @@ -0,0 +1,285 @@ +using DecentDB.AdoNet; +using DecentDB.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore; +using Xunit; + +namespace DecentDB.EntityFrameworkCore.Tests; + +/// +/// Tests for EF Core primitive collection translation (string[] stored as JSON). +/// Validates that json_array_length/json_extract are used for .Any(), .Count(), etc. 
+/// +public sealed class PrimitiveCollectionTests : IDisposable +{ + private readonly string _dbPath = Path.Combine(Path.GetTempPath(), $"test_ef_primcoll_{Guid.NewGuid():N}.ddb"); + + public void Dispose() + { + TryDelete(_dbPath); + TryDelete(_dbPath + "-wal"); + } + + [Fact] + public async Task StringArray_AnyWithoutPredicate_TranslatesToJsonArrayLength() + { + SeedData(); + await using var context = CreateContext(); + + var result = await context.Albums + .Where(a => a.Genres != null && a.Genres.Length > 0) + .Select(a => a.Name) + .ToListAsync(); + + Assert.Equal(2, result.Count); + Assert.Contains("Rock Album", result); + Assert.Contains("Jazz Album", result); + } + + [Fact] + public async Task StringArray_SelectArray_ReturnsJsonArray() + { + SeedData(); + await using var context = CreateContext(); + + var genres = await context.Albums + .Where(a => a.Genres != null && a.Genres.Length > 0) + .OrderBy(a => a.Name) + .Select(a => a.Genres) + .ToListAsync(); + + Assert.Equal(2, genres.Count); + Assert.NotNull(genres[0]); + Assert.Contains("Jazz", genres[0]!); + } + + [Fact] + public async Task StringArray_NullArray_ExcludedByLengthCheck() + { + SeedData(); + await using var context = CreateContext(); + + var allAlbums = await context.Albums.CountAsync(); + var withGenres = await context.Albums + .Where(a => a.Genres != null && a.Genres.Length > 0) + .CountAsync(); + + Assert.Equal(3, allAlbums); + Assert.Equal(2, withGenres); + } + + [Fact] + public async Task StringArray_EmptyArray_ExcludedByLengthCheck() + { + SeedEmptyArrayData(); + await using var context = CreateContext(); + + var withGenres = await context.Albums + .Where(a => a.Genres != null && a.Genres.Length > 0) + .CountAsync(); + + Assert.Equal(1, withGenres); + } + + [Fact] + public async Task StringArray_ElementAt_TranslatesToJsonExtract() + { + SeedData(); + await using var context = CreateContext(); + + var firstGenres = await context.Albums + .Where(a => a.Genres != null && a.Genres.Length > 0) 
+ .OrderBy(a => a.Name) + .Select(a => a.Genres![0]) + .ToListAsync(); + + Assert.Equal(2, firstGenres.Count); + Assert.Equal("Jazz", firstGenres[0]); + Assert.Equal("Rock", firstGenres[1]); + } + + [Fact] + public async Task StringArray_ElementAtWithDefault_TranslatesToJsonExtract() + { + SeedData(); + await using var context = CreateContext(); + + var firstGenres = await context.Albums + .OrderBy(a => a.Name) + .Select(a => a.Genres != null && a.Genres.Length > 0 ? a.Genres[0] : "Unknown") + .ToListAsync(); + + Assert.Equal(3, firstGenres.Count); + Assert.Equal("Jazz", firstGenres[0]); + Assert.Equal("Unknown", firstGenres[1]); + Assert.Equal("Rock", firstGenres[2]); + } + + [Fact] + public async Task StringArray_Contains_TranslatesToLikePattern() + { + SeedData(); + await using var context = CreateContext(); + + var result = await context.Albums + .Where(a => a.Genres != null && a.Genres.Contains("Rock")) + .Select(a => a.Name) + .ToListAsync(); + + Assert.Single(result); + Assert.Equal("Rock Album", result[0]); + } + + [Fact] + public async Task StringArray_ContainsMultipleMatches_ReturnsAll() + { + SeedWithSharedGenre(); + await using var context = CreateContext(); + + var result = await context.Albums + .Where(a => a.Genres != null && a.Genres.Contains("Blues")) + .OrderBy(a => a.Name) + .Select(a => a.Name) + .ToListAsync(); + + Assert.Equal(2, result.Count); + Assert.Equal("Blues Album", result[0]); + Assert.Equal("Jazz Album", result[1]); + } + + [Fact] + public async Task StringArray_FlattenClientSide_WorksCorrectly() + { + SeedData(); + await using var context = CreateContext(); + + var albumGenres = await context.Albums + .AsNoTracking() + .Where(a => a.Genres != null && a.Genres.Length > 0) + .Select(a => a.Genres) + .ToListAsync(); + + var uniqueGenres = albumGenres + .Where(g => g != null) + .SelectMany(g => g!) 
+ .Where(g => !string.IsNullOrWhiteSpace(g)) + .Distinct(StringComparer.OrdinalIgnoreCase) + .ToList(); + + Assert.Contains("Rock", uniqueGenres); + Assert.Contains("Metal", uniqueGenres); + Assert.Contains("Jazz", uniqueGenres); + Assert.Contains("Blues", uniqueGenres); + } + + private PrimCollContext CreateContext() + { + var optionsBuilder = new DbContextOptionsBuilder(); + optionsBuilder.UseDecentDB($"Data Source={_dbPath}"); + return new PrimCollContext(optionsBuilder.Options); + } + + private void SeedData() + { + using var conn = new DecentDBConnection($"Data Source={_dbPath}"); + conn.Open(); + using var cmd = conn.CreateCommand(); + + cmd.CommandText = "DROP TABLE IF EXISTS \"Albums\""; + cmd.ExecuteNonQuery(); + + cmd.CommandText = """ + CREATE TABLE "Albums" ( + "Id" INTEGER PRIMARY KEY, + "Name" TEXT NOT NULL, + "Genres" TEXT + ) + """; + cmd.ExecuteNonQuery(); + + cmd.CommandText = """INSERT INTO "Albums" ("Id", "Name", "Genres") VALUES (1, 'Rock Album', '["Rock","Metal"]')"""; + cmd.ExecuteNonQuery(); + cmd.CommandText = """INSERT INTO "Albums" ("Id", "Name", "Genres") VALUES (2, 'Jazz Album', '["Jazz","Blues"]')"""; + cmd.ExecuteNonQuery(); + cmd.CommandText = """INSERT INTO "Albums" ("Id", "Name", "Genres") VALUES (3, 'No Genre Album', NULL)"""; + cmd.ExecuteNonQuery(); + } + + private void SeedEmptyArrayData() + { + using var conn = new DecentDBConnection($"Data Source={_dbPath}"); + conn.Open(); + using var cmd = conn.CreateCommand(); + + cmd.CommandText = "DROP TABLE IF EXISTS \"Albums\""; + cmd.ExecuteNonQuery(); + + cmd.CommandText = """ + CREATE TABLE "Albums" ( + "Id" INTEGER PRIMARY KEY, + "Name" TEXT NOT NULL, + "Genres" TEXT + ) + """; + cmd.ExecuteNonQuery(); + + cmd.CommandText = """INSERT INTO "Albums" ("Id", "Name", "Genres") VALUES (1, 'Has Genres', '["Rock"]')"""; + cmd.ExecuteNonQuery(); + cmd.CommandText = """INSERT INTO "Albums" ("Id", "Name", "Genres") VALUES (2, 'Empty Genres', '[]')"""; + cmd.ExecuteNonQuery(); + 
cmd.CommandText = """INSERT INTO "Albums" ("Id", "Name", "Genres") VALUES (3, 'Null Genres', NULL)"""; + cmd.ExecuteNonQuery(); + } + + private void SeedWithSharedGenre() + { + using var conn = new DecentDBConnection($"Data Source={_dbPath}"); + conn.Open(); + using var cmd = conn.CreateCommand(); + + cmd.CommandText = "DROP TABLE IF EXISTS \"Albums\""; + cmd.ExecuteNonQuery(); + + cmd.CommandText = """ + CREATE TABLE "Albums" ( + "Id" INTEGER PRIMARY KEY, + "Name" TEXT NOT NULL, + "Genres" TEXT + ) + """; + cmd.ExecuteNonQuery(); + + cmd.CommandText = """INSERT INTO "Albums" ("Id", "Name", "Genres") VALUES (1, 'Jazz Album', '["Jazz","Blues"]')"""; + cmd.ExecuteNonQuery(); + cmd.CommandText = """INSERT INTO "Albums" ("Id", "Name", "Genres") VALUES (2, 'Blues Album', '["Blues","Soul"]')"""; + cmd.ExecuteNonQuery(); + cmd.CommandText = """INSERT INTO "Albums" ("Id", "Name", "Genres") VALUES (3, 'Rock Album', '["Rock","Metal"]')"""; + cmd.ExecuteNonQuery(); + } + + private static void TryDelete(string path) + { + if (File.Exists(path)) + File.Delete(path); + } + + private sealed class PrimCollContext(DbContextOptions options) : DbContext(options) + { + public DbSet Albums => Set(); + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity(entity => + { + entity.ToTable("Albums"); + entity.HasKey(x => x.Id); + }); + } + } + + private sealed class Album + { + public int Id { get; set; } + public string Name { get; set; } = string.Empty; + public string[]? 
Genres { get; set; } + } +} diff --git a/bindings/dotnet/tests/DecentDB.Tests/SqlParameterRewriterTests.cs b/bindings/dotnet/tests/DecentDB.Tests/SqlParameterRewriterTests.cs new file mode 100644 index 0000000..cc4a9f1 --- /dev/null +++ b/bindings/dotnet/tests/DecentDB.Tests/SqlParameterRewriterTests.cs @@ -0,0 +1,72 @@ +using System.Collections.Generic; +using System.Data.Common; +using DecentDB.AdoNet; +using Xunit; + +namespace DecentDB.Tests; + +public class SqlParameterRewriterTests +{ + [Fact] + public void Rewrite_NamedParam_StartingWithDigit_IsRewritten() + { + var sql = "SELECT * FROM t WHERE id = @8__locals2_artistApiKey"; + var parameters = new List + { + new DecentDBParameter { ParameterName = "@8__locals2_artistApiKey", Value = 42 } + }; + + var (rewritten, paramMap) = SqlParameterRewriter.Rewrite(sql, parameters); + + Assert.DoesNotContain("@8__locals2", rewritten); + Assert.Contains("$", rewritten); + Assert.Single(paramMap); + } + + [Fact] + public void Rewrite_NamedParam_StartingWithLetter_IsRewritten() + { + var sql = "SELECT * FROM t WHERE id = @myParam"; + var parameters = new List + { + new DecentDBParameter { ParameterName = "@myParam", Value = 1 } + }; + + var (rewritten, paramMap) = SqlParameterRewriter.Rewrite(sql, parameters); + + Assert.DoesNotContain("@myParam", rewritten); + Assert.Contains("$", rewritten); + Assert.Single(paramMap); + } + + [Fact] + public void Rewrite_MultipleDigitPrefixedParams_EachGetUniqueIndex() + { + var sql = "SELECT * FROM t WHERE a = @8__locals1_x AND b = @8__locals2_y"; + var parameters = new List + { + new DecentDBParameter { ParameterName = "@8__locals1_x", Value = 1 }, + new DecentDBParameter { ParameterName = "@8__locals2_y", Value = 2 } + }; + + var (rewritten, paramMap) = SqlParameterRewriter.Rewrite(sql, parameters); + + Assert.DoesNotContain("@8__locals", rewritten); + Assert.Equal(2, paramMap.Count); + } + + [Fact] + public void Rewrite_SameDigitPrefixedParam_UsedTwice_GetsSameIndex() + { + var sql = 
"SELECT * FROM t WHERE a = @5__2 OR b = @5__2"; + var parameters = new List + { + new DecentDBParameter { ParameterName = "@5__2", Value = 99 } + }; + + var (rewritten, paramMap) = SqlParameterRewriter.Rewrite(sql, parameters); + + Assert.DoesNotContain("@5__2", rewritten); + Assert.Single(paramMap); + } +} diff --git a/bindings/dotnet/tests/DecentDB.Tests/SqliteCompatibilityTests.cs b/bindings/dotnet/tests/DecentDB.Tests/SqliteCompatibilityTests.cs index c1d1442..550c82e 100644 --- a/bindings/dotnet/tests/DecentDB.Tests/SqliteCompatibilityTests.cs +++ b/bindings/dotnet/tests/DecentDB.Tests/SqliteCompatibilityTests.cs @@ -1,4 +1,5 @@ using DecentDB.AdoNet; +using DecentDB.Native; using Xunit; namespace DecentDB.Tests; @@ -297,4 +298,641 @@ private static bool ScalarIsNull(DecentDBConnection conn, string sql) } #endregion + + #region EXISTS with derived-table JOINs + + [Fact] + public void ExistsSubquery_WithDerivedTableJoin_ReturnsTrue() + { + using var conn = new DecentDBConnection($"Data Source={_dbPath}"); + conn.Open(); + + Exec(conn, """CREATE TABLE "AccessControls" ("Id" INTEGER NOT NULL, "LibraryId" INTEGER NOT NULL, "GroupId" INTEGER NOT NULL)"""); + Exec(conn, """CREATE TABLE "GroupMembers" ("UserId" INTEGER NOT NULL, "GroupId" INTEGER NOT NULL)"""); + Exec(conn, """INSERT INTO "AccessControls" ("Id", "LibraryId", "GroupId") VALUES (1, 100, 200)"""); + Exec(conn, """INSERT INTO "GroupMembers" ("UserId", "GroupId") VALUES (50, 200)"""); + + var result = ScalarBool(conn, """ + SELECT EXISTS ( + SELECT 1 + FROM "AccessControls" AS "a" + INNER JOIN ( + SELECT "g"."GroupId" + FROM "GroupMembers" AS "g" + WHERE "g"."UserId" = 50 + ) AS "sub" ON "a"."GroupId" = "sub"."GroupId" + WHERE "a"."LibraryId" = 100) + """); + Assert.True(result); + } + + [Fact] + public void ExistsSubquery_WithDerivedTableJoin_ReturnsFalse_WhenNoMatch() + { + using var conn = new DecentDBConnection($"Data Source={_dbPath}"); + conn.Open(); + + Exec(conn, """CREATE TABLE "AccessControls" 
("Id" INTEGER NOT NULL, "LibraryId" INTEGER NOT NULL, "GroupId" INTEGER NOT NULL)"""); + Exec(conn, """CREATE TABLE "GroupMembers" ("UserId" INTEGER NOT NULL, "GroupId" INTEGER NOT NULL)"""); + Exec(conn, """INSERT INTO "AccessControls" ("Id", "LibraryId", "GroupId") VALUES (1, 100, 200)"""); + Exec(conn, """INSERT INTO "GroupMembers" ("UserId", "GroupId") VALUES (50, 999)"""); + + var result = ScalarBool(conn, """ + SELECT EXISTS ( + SELECT 1 + FROM "AccessControls" AS "a" + INNER JOIN ( + SELECT "g"."GroupId" + FROM "GroupMembers" AS "g" + WHERE "g"."UserId" = 50 + ) AS "sub" ON "a"."GroupId" = "sub"."GroupId" + WHERE "a"."LibraryId" = 100) + """); + Assert.False(result); + } + + [Fact] + public void ExistsSubquery_WithDerivedTableJoin_AndParameters() + { + using var conn = new DecentDBConnection($"Data Source={_dbPath}"); + conn.Open(); + + Exec(conn, """CREATE TABLE "AccessControls" ("Id" INTEGER NOT NULL, "LibraryId" INTEGER NOT NULL, "GroupId" INTEGER NOT NULL)"""); + Exec(conn, """CREATE TABLE "GroupMembers" ("UserId" INTEGER NOT NULL, "GroupId" INTEGER NOT NULL)"""); + Exec(conn, """INSERT INTO "AccessControls" ("Id", "LibraryId", "GroupId") VALUES (1, 100, 200)"""); + Exec(conn, """INSERT INTO "GroupMembers" ("UserId", "GroupId") VALUES (50, 200)"""); + + using var cmd = conn.CreateCommand(); + cmd.CommandText = """ + SELECT EXISTS ( + SELECT 1 + FROM "AccessControls" AS "a" + INNER JOIN ( + SELECT "g"."GroupId" + FROM "GroupMembers" AS "g" + WHERE "g"."UserId" = @userId + ) AS "sub" ON "a"."GroupId" = "sub"."GroupId" + WHERE "a"."LibraryId" = @libId) + """; + cmd.Parameters.Add(new DecentDBParameter("@userId", 50)); + cmd.Parameters.Add(new DecentDBParameter("@libId", 100)); + using var reader = cmd.ExecuteReader(); + Assert.True(reader.Read()); + Assert.True(reader.GetBoolean(0)); + } + + private static bool ScalarBool(DecentDBConnection conn, string sql) + { + using var cmd = conn.CreateCommand(); + cmd.CommandText = sql; + using var reader = 
cmd.ExecuteReader(); + reader.Read(); + return reader.GetBoolean(0); + } + + #endregion + + #region Type Affinity (ADR-0099) + + [Fact] + public void IntegerEqualsTextNumeric_ReturnsTrue() + { + using var conn = Open(); + Exec(conn, "CREATE TABLE ta (id INTEGER)"); + Exec(conn, "INSERT INTO ta VALUES (42)"); + Assert.Equal(1, ScalarInt(conn, "SELECT COUNT(*) FROM ta WHERE id = '42'")); + } + + [Fact] + public void IntegerEqualsTextNonNumeric_ReturnsFalse() + { + using var conn = Open(); + Exec(conn, "CREATE TABLE tb (id INTEGER)"); + Exec(conn, "INSERT INTO tb VALUES (42)"); + Assert.Equal(0, ScalarInt(conn, "SELECT COUNT(*) FROM tb WHERE id = 'abc'")); + } + + [Fact] + public void RealEqualsTextNumeric_ReturnsTrue() + { + using var conn = Open(); + Exec(conn, "CREATE TABLE tc (val REAL)"); + Exec(conn, "INSERT INTO tc VALUES (3.14)"); + Assert.Equal(1, ScalarInt(conn, "SELECT COUNT(*) FROM tc WHERE val = '3.14'")); + } + + [Fact] + public void TextEqualsInteger_ReturnsTrue() + { + using var conn = Open(); + Exec(conn, "CREATE TABLE td (name TEXT)"); + Exec(conn, "INSERT INTO td VALUES ('100')"); + Assert.Equal(1, ScalarInt(conn, "SELECT COUNT(*) FROM td WHERE name = 100")); + } + + [Fact] + public void IntegerLessThanNonNumericText_ReturnsTrue() + { + using var conn = Open(); + Exec(conn, "CREATE TABLE te (id INTEGER)"); + Exec(conn, "INSERT INTO te VALUES (42)"); + Assert.Equal(1, ScalarInt(conn, "SELECT COUNT(*) FROM te WHERE id < 'zzz'")); + } + + [Fact] + public void IntegerEqualsTextParam_ReturnsTrue() + { + using var conn = Open(); + Exec(conn, "CREATE TABLE tf (id INTEGER)"); + Exec(conn, "INSERT INTO tf VALUES (42)"); + using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT COUNT(*) FROM tf WHERE id = @p0"; + cmd.Parameters.Add(new DecentDBParameter("@p0", "42")); + using var reader = cmd.ExecuteReader(); + reader.Read(); + Assert.Equal(1L, reader.GetInt64(0)); + } + + private static int ScalarInt(DecentDBConnection conn, string sql) + { + using 
var cmd = conn.CreateCommand(); + cmd.CommandText = sql; + using var reader = cmd.ExecuteReader(); + reader.Read(); + return (int)reader.GetInt64(0); + } + + #endregion + + #region Partial (Filtered) Unique Index + + [Fact] + public void PartialUniqueIndex_AllowsDuplicatesExcludedByPredicate() + { + using var conn = new DecentDBConnection($"Data Source={_dbPath}"); + conn.Open(); + using var cmd = conn.CreateCommand(); + + cmd.CommandText = "CREATE TABLE Libraries (Id INTEGER PRIMARY KEY, Name TEXT NOT NULL, Type INTEGER NOT NULL)"; + cmd.ExecuteNonQuery(); + + cmd.CommandText = """CREATE UNIQUE INDEX "IX_Libraries_Type" ON "Libraries" ("Type") WHERE "Type" != 3"""; + cmd.ExecuteNonQuery(); + + // Type=3 is excluded from the unique index โ€” duplicates should be allowed + cmd.CommandText = "INSERT INTO Libraries (Id, Name, Type) VALUES (11, 'Storage One', 3)"; + cmd.ExecuteNonQuery(); + + cmd.CommandText = "INSERT INTO Libraries (Id, Name, Type) VALUES (12, 'Storage Two', 3)"; + cmd.ExecuteNonQuery(); // Should not throw + + Assert.Equal(2, ScalarInt(conn, "SELECT COUNT(*) FROM Libraries WHERE Type = 3")); + } + + [Fact] + public void PartialUniqueIndex_EnforcesDuplicatesMatchingPredicate() + { + using var conn = new DecentDBConnection($"Data Source={_dbPath}"); + conn.Open(); + using var cmd = conn.CreateCommand(); + + cmd.CommandText = "CREATE TABLE Libraries (Id INTEGER PRIMARY KEY, Name TEXT NOT NULL, Type INTEGER NOT NULL)"; + cmd.ExecuteNonQuery(); + + cmd.CommandText = """CREATE UNIQUE INDEX "IX_Libraries_Type" ON "Libraries" ("Type") WHERE "Type" != 3"""; + cmd.ExecuteNonQuery(); + + // Type=1 IS covered by the unique index โ€” duplicates should fail + cmd.CommandText = "INSERT INTO Libraries (Id, Name, Type) VALUES (21, 'Lib A', 1)"; + cmd.ExecuteNonQuery(); + + cmd.CommandText = "INSERT INTO Libraries (Id, Name, Type) VALUES (22, 'Lib B', 1)"; + Assert.Throws(() => cmd.ExecuteNonQuery()); + } + + [Fact] + public void 
PartialUniqueIndex_MixedPredicateAndExcludedRows() + { + using var conn = new DecentDBConnection($"Data Source={_dbPath}"); + conn.Open(); + using var cmd = conn.CreateCommand(); + + cmd.CommandText = "CREATE TABLE Libraries (Id INTEGER PRIMARY KEY, Name TEXT NOT NULL, Type INTEGER NOT NULL)"; + cmd.ExecuteNonQuery(); + + cmd.CommandText = """CREATE UNIQUE INDEX "IX_Libraries_Type" ON "Libraries" ("Type") WHERE "Type" != 3"""; + cmd.ExecuteNonQuery(); + + // Insert multiple Type=3 (excluded from unique constraint) + cmd.CommandText = "INSERT INTO Libraries (Id, Name, Type) VALUES (1, 'S1', 3)"; + cmd.ExecuteNonQuery(); + cmd.CommandText = "INSERT INTO Libraries (Id, Name, Type) VALUES (2, 'S2', 3)"; + cmd.ExecuteNonQuery(); + cmd.CommandText = "INSERT INTO Libraries (Id, Name, Type) VALUES (3, 'S3', 3)"; + cmd.ExecuteNonQuery(); + + // Insert unique Type=1 and Type=2 (covered by unique constraint) + cmd.CommandText = "INSERT INTO Libraries (Id, Name, Type) VALUES (4, 'A', 1)"; + cmd.ExecuteNonQuery(); + cmd.CommandText = "INSERT INTO Libraries (Id, Name, Type) VALUES (5, 'B', 2)"; + cmd.ExecuteNonQuery(); + + Assert.Equal(5, ScalarInt(conn, "SELECT COUNT(*) FROM Libraries")); + Assert.Equal(3, ScalarInt(conn, "SELECT COUNT(*) FROM Libraries WHERE Type = 3")); + } + + #endregion + + #region UNION subquery column resolution + + [Fact] + public void UnionSubquery_NamedColumnProjection_ResolvesCorrectly() + { + using var conn = new DecentDBConnection($"Data Source={_dbPath}"); + conn.Open(); + using var cmd = conn.CreateCommand(); + + cmd.CommandText = "CREATE TABLE Items (Id INTEGER PRIMARY KEY, Name TEXT, Category TEXT)"; + cmd.ExecuteNonQuery(); + cmd.CommandText = "INSERT INTO Items VALUES (1, 'Alpha', 'A')"; + cmd.ExecuteNonQuery(); + cmd.CommandText = "INSERT INTO Items VALUES (2, 'Beta', 'B')"; + cmd.ExecuteNonQuery(); + cmd.CommandText = "INSERT INTO Items VALUES (3, 'Gamma', 'A')"; + cmd.ExecuteNonQuery(); + + // Qualified column references on UNION subquery 
alias + cmd.CommandText = """ + SELECT "u"."Id", "u"."Name" FROM ( + SELECT "t"."Id", "t"."Name" FROM "Items" AS "t" WHERE "t"."Category" = 'A' + UNION + SELECT "t0"."Id", "t0"."Name" FROM "Items" AS "t0" WHERE "t0"."Category" = 'B' + ) AS "u" ORDER BY "u"."Name" + """; + using var reader = cmd.ExecuteReader(); + var names = new List(); + while (reader.Read()) + names.Add(reader.GetString(1)); + + Assert.Equal(3, names.Count); + Assert.Equal("Alpha", names[0]); + Assert.Equal("Beta", names[1]); + Assert.Equal("Gamma", names[2]); + } + + [Fact] + public void UnionSubquery_CountOverNamedColumns_Works() + { + using var conn = new DecentDBConnection($"Data Source={_dbPath}"); + conn.Open(); + using var cmd = conn.CreateCommand(); + + cmd.CommandText = "CREATE TABLE Settings (Id INTEGER PRIMARY KEY, Key TEXT, Value TEXT)"; + cmd.ExecuteNonQuery(); + cmd.CommandText = "INSERT INTO Settings VALUES (1, 'validation.min', '5')"; + cmd.ExecuteNonQuery(); + cmd.CommandText = "INSERT INTO Settings VALUES (2, 'conversion.format', 'mp3')"; + cmd.ExecuteNonQuery(); + cmd.CommandText = "INSERT INTO Settings VALUES (3, 'other.setting', 'yes')"; + cmd.ExecuteNonQuery(); + + // COUNT with LIKE parameters over UNION โ€” pattern from EF Core OR filters + cmd.CommandText = """ + SELECT COUNT(*) FROM ( + SELECT "s"."Id", "s"."Key", "s"."Value" FROM "Settings" AS "s" WHERE "s"."Key" LIKE @p0 + UNION + SELECT "s0"."Id", "s0"."Key", "s0"."Value" FROM "Settings" AS "s0" WHERE "s0"."Key" LIKE @p1 + ) AS "u" + """; + cmd.Parameters.Add(new DecentDBParameter("@p0", "%validation%")); + cmd.Parameters.Add(new DecentDBParameter("@p1", "%conversion%")); + Assert.Equal(2L, cmd.ExecuteScalar()); + } + + [Fact] + public void UnionSubquery_WithLimitOffset_ResolvesColumns() + { + using var conn = new DecentDBConnection($"Data Source={_dbPath}"); + conn.Open(); + using var cmd = conn.CreateCommand(); + + cmd.CommandText = "CREATE TABLE Records (Id INTEGER PRIMARY KEY, Tag TEXT)"; + cmd.ExecuteNonQuery(); 
+ for (int i = 1; i <= 5; i++) + { + cmd.CommandText = $"INSERT INTO Records VALUES ({i}, 'tag{i}')"; + cmd.ExecuteNonQuery(); + } + + cmd.CommandText = """ + SELECT "u"."Id", "u"."Tag" FROM ( + SELECT "r"."Id", "r"."Tag" FROM "Records" AS "r" WHERE "r"."Id" <= 3 + UNION + SELECT "r0"."Id", "r0"."Tag" FROM "Records" AS "r0" WHERE "r0"."Id" >= 4 + ) AS "u" ORDER BY "u"."Id" LIMIT @p0 OFFSET @p1 + """; + cmd.Parameters.Add(new DecentDBParameter("@p0", 3)); + cmd.Parameters.Add(new DecentDBParameter("@p1", 1)); + using var reader = cmd.ExecuteReader(); + var ids = new List(); + while (reader.Read()) + ids.Add(reader.GetInt64(0)); + + Assert.Equal([2L, 3L, 4L], ids); + } + + #endregion + + #region Composite Primary Key + + [Fact] + public void CompositePrimaryKey_WhereOnFirstColumn_ReturnsAllMatches() + { + using var conn = Open(); + Exec(conn, "CREATE TABLE cpk1 (a INTEGER NOT NULL, b INTEGER NOT NULL, name TEXT, PRIMARY KEY (a, b))"); + Exec(conn, "INSERT INTO cpk1 VALUES (1,1,'a1b1')"); + Exec(conn, "INSERT INTO cpk1 VALUES (1,2,'a1b2')"); + Exec(conn, "INSERT INTO cpk1 VALUES (2,1,'a2b1')"); + + using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT name FROM cpk1 WHERE a = 1 ORDER BY b"; + var names = new List(); + using (var r = cmd.ExecuteReader()) while (r.Read()) names.Add(r.GetString(0)); + Assert.Equal(["a1b1", "a1b2"], names); + } + + [Fact] + public void CompositePrimaryKey_WhereOnBothColumns_ReturnsSingleRow() + { + using var conn = Open(); + Exec(conn, "CREATE TABLE cpk2 (a INTEGER NOT NULL, b INTEGER NOT NULL, name TEXT, PRIMARY KEY (a, b))"); + Exec(conn, "INSERT INTO cpk2 VALUES (1,1,'a1b1')"); + Exec(conn, "INSERT INTO cpk2 VALUES (1,2,'a1b2')"); + Exec(conn, "INSERT INTO cpk2 VALUES (2,1,'a2b1')"); + + using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT name FROM cpk2 WHERE a = 1 AND b = 2"; + var names = new List(); + using (var r = cmd.ExecuteReader()) while (r.Read()) names.Add(r.GetString(0)); + Assert.Equal(["a1b2"], 
names); + } + + [Fact] + public void CompositePrimaryKey_WhereOnSecondColumn_ReturnsCorrectRows() + { + using var conn = Open(); + Exec(conn, "CREATE TABLE cpk3 (a INTEGER NOT NULL, b INTEGER NOT NULL, name TEXT, PRIMARY KEY (a, b))"); + Exec(conn, "INSERT INTO cpk3 VALUES (1,1,'a1b1')"); + Exec(conn, "INSERT INTO cpk3 VALUES (1,2,'a1b2')"); + Exec(conn, "INSERT INTO cpk3 VALUES (2,1,'a2b1')"); + + using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT name FROM cpk3 WHERE b = 1 ORDER BY a"; + var names = new List(); + using (var r = cmd.ExecuteReader()) while (r.Read()) names.Add(r.GetString(0)); + Assert.Equal(["a1b1", "a2b1"], names); + } + + #endregion + + #region GROUP BY with ORDER BY on aggregate + + [Fact] + public void GroupBy_OrderByCountDesc_ReturnsCorrectOrder() + { + using var conn = Open(); + Exec(conn, "CREATE TABLE gb_plays (user_id INTEGER, song_id INTEGER)"); + Exec(conn, "INSERT INTO gb_plays VALUES (1, 1)"); + Exec(conn, "INSERT INTO gb_plays VALUES (1, 1)"); + Exec(conn, "INSERT INTO gb_plays VALUES (1, 2)"); + + using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT song_id, COUNT(*) as cnt FROM gb_plays WHERE user_id = 1 GROUP BY song_id ORDER BY cnt DESC"; + using var reader = cmd.ExecuteReader(); + + var rows = new List<(long songId, long count)>(); + while (reader.Read()) + { + rows.Add((reader.GetInt64(0), reader.GetInt64(1))); + } + + Assert.Equal(2, rows.Count); + Assert.Equal(1, rows[0].songId); + Assert.Equal(2, rows[0].count); + Assert.Equal(2, rows[1].songId); + Assert.Equal(1, rows[1].count); + } + + [Fact] + public void GroupBy_OrderByRawCountDesc_ReturnsCorrectOrder() + { + using var conn = Open(); + Exec(conn, "CREATE TABLE gb_plays2 (user_id INTEGER, song_id INTEGER)"); + Exec(conn, "INSERT INTO gb_plays2 VALUES (1, 1)"); + Exec(conn, "INSERT INTO gb_plays2 VALUES (1, 1)"); + Exec(conn, "INSERT INTO gb_plays2 VALUES (1, 2)"); + + using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT song_id, 
COUNT(*) as cnt FROM gb_plays2 WHERE user_id = 1 GROUP BY song_id ORDER BY COUNT(*) DESC"; + using var reader = cmd.ExecuteReader(); + + var rows = new List<(long songId, long count)>(); + while (reader.Read()) + { + rows.Add((reader.GetInt64(0), reader.GetInt64(1))); + } + + Assert.Equal(2, rows.Count); + Assert.Equal(1, rows[0].songId); + Assert.Equal(2, rows[0].count); + Assert.Equal(2, rows[1].songId); + Assert.Equal(1, rows[1].count); + } + + #endregion + + #region DELETE with subquery parameters + + [Fact] + public void Delete_WithExistsSubqueryParam_BindsCorrectly() + { + using var conn = Open(); + Exec(conn, @" + CREATE TABLE del_artists (""Id"" INTEGER PRIMARY KEY, ""LibraryId"" INTEGER, ""Name"" TEXT); + CREATE TABLE del_contributors (""Id"" INTEGER PRIMARY KEY, ""ArtistId"" INTEGER, ""Name"" TEXT); + INSERT INTO del_artists VALUES (1, 10, 'Artist1'); + INSERT INTO del_artists VALUES (2, 20, 'Artist2'); + INSERT INTO del_contributors VALUES (1, 1, 'Contrib1'); + INSERT INTO del_contributors VALUES (2, 2, 'Contrib2'); + "); + + using var cmd = conn.CreateCommand(); + cmd.CommandText = @" + DELETE FROM ""del_contributors"" + WHERE EXISTS ( + SELECT 1 FROM ""del_contributors"" AS ""c"" + INNER JOIN ""del_artists"" AS ""a"" ON ""c"".""ArtistId"" = ""a"".""Id"" + WHERE ""a"".""LibraryId"" = @p0 + AND ""del_contributors"".""Id"" = ""c"".""Id"" + )"; + var p = cmd.CreateParameter(); + p.ParameterName = "@p0"; + p.Value = 10L; + cmd.Parameters.Add(p); + var deleted = cmd.ExecuteNonQuery(); + + Assert.Equal(1, deleted); + + // Verify only contrib for library 20 remains + cmd.Parameters.Clear(); + cmd.CommandText = @"SELECT COUNT(*) FROM ""del_contributors"""; + var remaining = (long)cmd.ExecuteScalar()!; + Assert.Equal(1, remaining); + } + + [Fact] + public void Delete_WithInSubqueryParam_BindsCorrectly() + { + using var conn = Open(); + Exec(conn, @" + CREATE TABLE del2_artists (""Id"" INTEGER PRIMARY KEY, ""LibraryId"" INTEGER, ""Name"" TEXT); + CREATE TABLE 
del2_contributors (""Id"" INTEGER PRIMARY KEY, ""ArtistId"" INTEGER, ""Name"" TEXT); + INSERT INTO del2_artists VALUES (1, 10, 'Artist1'); + INSERT INTO del2_artists VALUES (2, 20, 'Artist2'); + INSERT INTO del2_contributors VALUES (1, 1, 'Contrib1'); + INSERT INTO del2_contributors VALUES (2, 2, 'Contrib2'); + "); + + using var cmd = conn.CreateCommand(); + cmd.CommandText = @" + DELETE FROM ""del2_contributors"" + WHERE ""del2_contributors"".""Id"" IN ( + SELECT ""c0"".""Id"" + FROM ""del2_contributors"" AS ""c0"" + LEFT JOIN ""del2_artists"" AS ""a"" ON ""c0"".""ArtistId"" = ""a"".""Id"" + WHERE ""a"".""Id"" IS NOT NULL AND ""a"".""LibraryId"" = @p0 + )"; + var p = cmd.CreateParameter(); + p.ParameterName = "@p0"; + p.Value = 10L; + cmd.Parameters.Add(p); + var deleted = cmd.ExecuteNonQuery(); + + Assert.Equal(1, deleted); + + cmd.Parameters.Clear(); + cmd.CommandText = @"SELECT COUNT(*) FROM ""del2_contributors"""; + var remaining = (long)cmd.ExecuteScalar()!; + Assert.Equal(1, remaining); + } + #endregion + + #region JSON Functions + + [Fact] + public void JsonArrayLength_ReturnsElementCount() + { + using var conn = Open(); + using var cmd = conn.CreateCommand(); + + cmd.CommandText = @"SELECT json_array_length('[1,2,3]')"; + Assert.Equal(3L, (long)cmd.ExecuteScalar()!); + + cmd.CommandText = @"SELECT json_array_length('[]')"; + Assert.Equal(0L, (long)cmd.ExecuteScalar()!); + + cmd.CommandText = @"SELECT json_array_length('[""a"",""b""]')"; + Assert.Equal(2L, (long)cmd.ExecuteScalar()!); + } + + [Fact] + public void JsonArrayLength_NullInput_ReturnsNull() + { + using var conn = Open(); + using var cmd = conn.CreateCommand(); + cmd.CommandText = @"SELECT json_array_length(NULL)"; + Assert.True(cmd.ExecuteScalar() is DBNull); + } + + [Fact] + public void JsonArrayLength_NonArrayInput_ReturnsZero() + { + using var conn = Open(); + using var cmd = conn.CreateCommand(); + cmd.CommandText = @"SELECT json_array_length('{""key"":""value""}')"; + Assert.Equal(0L, 
(long)cmd.ExecuteScalar()!); + } + + [Fact] + public void JsonArrayLength_WithPath_ReturnsNestedArrayCount() + { + using var conn = Open(); + using var cmd = conn.CreateCommand(); + cmd.CommandText = @"SELECT json_array_length('{""items"":[1,2,3,4]}', '$.items')"; + Assert.Equal(4L, (long)cmd.ExecuteScalar()!); + } + + [Fact] + public void JsonArrayLength_OnColumn_WorksCorrectly() + { + using var conn = Open(); + using var cmd = conn.CreateCommand(); + cmd.CommandText = @"CREATE TABLE json_test (id INTEGER PRIMARY KEY, data TEXT)"; + cmd.ExecuteNonQuery(); + + cmd.CommandText = @"INSERT INTO json_test (id, data) VALUES (1, '[""rock"",""pop""]')"; + cmd.ExecuteNonQuery(); + cmd.CommandText = @"INSERT INTO json_test (id, data) VALUES (2, '[""jazz""]')"; + cmd.ExecuteNonQuery(); + cmd.CommandText = @"INSERT INTO json_test (id, data) VALUES (3, NULL)"; + cmd.ExecuteNonQuery(); + + cmd.CommandText = @"SELECT id FROM json_test WHERE json_array_length(data) > 1"; + Assert.Equal(1L, (long)cmd.ExecuteScalar()!); + } + + [Fact] + public void JsonExtract_ReturnsValueAtPath() + { + using var conn = Open(); + using var cmd = conn.CreateCommand(); + + cmd.CommandText = @"SELECT json_extract('[""rock"",""pop"",""jazz""]', '$[0]')"; + Assert.Equal("rock", cmd.ExecuteScalar()!.ToString()); + + cmd.CommandText = @"SELECT json_extract('[""rock"",""pop"",""jazz""]', '$[2]')"; + Assert.Equal("jazz", cmd.ExecuteScalar()!.ToString()); + } + + [Fact] + public void JsonExtract_ObjectKey_ReturnsValue() + { + using var conn = Open(); + using var cmd = conn.CreateCommand(); + cmd.CommandText = @"SELECT json_extract('{""name"":""test"",""count"":42}', '$.name')"; + Assert.Equal("test", cmd.ExecuteScalar()!.ToString()); + + cmd.CommandText = @"SELECT json_extract('{""name"":""test"",""count"":42}', '$.count')"; + Assert.Equal(42L, (long)cmd.ExecuteScalar()!); + } + + [Fact] + public void JsonExtract_NullInput_ReturnsNull() + { + using var conn = Open(); + using var cmd = conn.CreateCommand(); + 
cmd.CommandText = @"SELECT json_extract(NULL, '$[0]')"; + Assert.True(cmd.ExecuteScalar() is DBNull); + } + + [Fact] + public void JsonExtract_OutOfBounds_ReturnsNull() + { + using var conn = Open(); + using var cmd = conn.CreateCommand(); + cmd.CommandText = @"SELECT json_extract('[""a""]', '$[5]')"; + Assert.True(cmd.ExecuteScalar() is DBNull); + } + + #endregion } diff --git a/decentdb.nimble b/decentdb.nimble index 06a4b04..8bedb79 100644 --- a/decentdb.nimble +++ b/decentdb.nimble @@ -1,4 +1,4 @@ -version = "1.3.0" +version = "1.4.0" author = "DecentDB contributors" description = "DecentDB engine" license = "Apache-2.0" @@ -11,7 +11,7 @@ requires "zip >= 0.3.1" task build_lib, "Build DecentDB shared library (C API)": - exec "nim c --app:lib -d:libpg_query -d:release --mm:arc --threads:on --outdir:build src/c_api.nim" + exec "nim c --app:lib -d:libpg_query -d:release --mm:arc --threads:on -d:noSignalHandler -d:useMalloc --outdir:build src/c_api.nim" task test_bindings_dotnet, "Run .NET binding tests": exec "ln -sf libc_api.so build/libdecentdb.so" diff --git a/design/INMEMORY_SUPPORT_PLAN.md b/design/INMEMORY_SUPPORT_PLAN.md new file mode 100644 index 0000000..69cea63 --- /dev/null +++ b/design/INMEMORY_SUPPORT_PLAN.md @@ -0,0 +1,170 @@ +# In-Memory Database Support Plan + +## 1. Overview +The goal is to support in-memory databases in DecentDB, similar to SQLite's `:memory:`. +The most performant yet simplest way to achieve this is to implement an `InMemoryVfs` (Virtual File System) that stores file contents in memory. +This approach requires zero architectural changes to the core engine (`Pager`, `Wal`, `BTree`), preserves Snapshot Isolation (concurrent readers), and bounds memory usage perfectly. + +## 2. Architecture +- **`MemVfs`**: A new VFS implementation that stores files in memory. +- **`MemVfsFile`**: A subclass of `VfsFile` that holds a `string` or `seq[byte]` buffer. 
+- **`Vfs` Interface Expansion**: Add `getFileSize`, `fileExists`, and `removeFile` to the `Vfs` interface to remove direct `os` module dependencies in the engine. +- **`openDb`**: Modified to detect `:memory:` and instantiate a `MemVfs` instead of `OsVfs`. + +## 3. Detailed Design + +### 3.1 Refactoring `VfsFile` +Currently, `VfsFile` is a concrete type containing a `File` handle: +```nim +type VfsFile* = ref object + path*: string + file*: File # <-- This must move to a subclass + lock*: Lock + bufferedDirty*: Atomic[bool] +``` +We will refactor it to be an extensible base class: +```nim +type VfsFile* = ref object of RootObj + path*: string + lock*: Lock + bufferedDirty*: Atomic[bool] +``` +`OsVfs` will define `OsVfsFile` inheriting from `VfsFile` and adding the `file: File` field. +`MemVfs` will define `MemVfsFile` inheriting from `VfsFile` and adding the `data: seq[byte]` field. + +#### 3.1.1 Affected Files +The following files directly access `VfsFile.file` and must be updated to cast to `OsVfsFile`: +- `src/vfs/os_vfs.nim` - All VFS methods that access `file.file` +- `src/engine.nim` - Line 34 defines `file*: VfsFile` in `Db` object +- `src/pager/pager.nim` - Line 45 defines `file*: VfsFile` in `Pager` object +- `src/pager/db_header.nim` - Lines 235, 244 use `VfsFile` in `readHeader`/`writeHeader` +- `src/wal/wal.nim` - Line 70 defines `file*: VfsFile` in `Wal` object + +**Migration Strategy**: Since `VfsFile` methods receive the file as a parameter, the VFS implementation (OsVfs or MemVfs) will cast the `VfsFile` to its concrete type internally. No changes required in callers. + +### 3.2 Expanding the `Vfs` Interface +Currently, `engine.nim`, `pager.nim`, and `wal.nim` use `os.getFileInfo`, `os.fileExists`, and `os.removeFile` directly. 
**Run `grep -rE 'os\.(getFileInfo|fileExists|removeFile|fileSize)' src/` before implementation to confirm all locations** (note: `-E` is required — plain `grep` treats `(…|…)` literally under basic regular expressions). + +These must be abstracted into the `Vfs` interface: +```nim +method getFileSize*(vfs: Vfs, path: string): Result[int64] {.base.} +method fileExists*(vfs: Vfs, path: string): bool {.base.} +method removeFile*(vfs: Vfs, path: string): Result[Void] {.base.} +``` +`OsVfs` will implement these using the `os` module. `MemVfs` will implement these by checking its internal `files` table. + +### 3.3 Implementing `MemVfs` +Create `src/vfs/mem_vfs.nim`: +```nim +type MemVfsFile* = ref object of VfsFile + data*: seq[byte] + +type MemVfs* = ref object of Vfs + files*: Table[string, MemVfsFile] + vfsLock*: Lock +``` +- `open`: If `create` is true, create a new `MemVfsFile` and add it to `files`. If false, look it up. +- `read` / `readStr`: Acquire the file's lock, `copyMem` from `data` to the buffer. +- `write` / `writeStr`: Acquire the file's lock, resize `data` if `offset + len > data.len`, `copyMem` from the buffer to `data`. +- `fsync`: No-op (return `okVoid()`). +- `truncate`: Acquire the file's lock, resize `data`. +- `close`: Remove the file from the `files` table and release memory. This ensures prompt cleanup when `Db` is closed. +- `getFileSize`: Return `data.len` as `int64`. Returns `0` if file doesn't exist or `data` is empty. +- `fileExists`: Check if path exists in `files` table. +- `removeFile`: Remove the file from `files` table, releasing memory. +- `supportsMmap`: Return `false`. Memory-mapped I/O is not supported for in-memory files. +- `mapWritable` / `unmap`: Return error `ERR_INTERNAL` - not supported. + +#### 3.3.1 Memory Management +- Files are owned by the `MemVfs` instance. +- When `close()` is called, the file is removed from `files` table immediately (not deferred to GC). +- When the `Db` object is closed, it should call `vfs.close()` on all open files, then the `MemVfs` itself can be GC'd. 
+- This ensures deterministic memory release, important for test scenarios that create/destroy many in-memory databases. + +### 3.4 Modifying `openDb` +In `src/engine.nim`: +```nim +proc openDb*(path: string, cachePages: int = 1024): Result[Db] = + let isMemory = path.endsWith(":memory:") + let vfs: Vfs = if isMemory: newMemVfs() else: newOsVfs() + ... +``` +Replace all direct `os` calls (`getFileInfo`, `fileExists`, `removeFile`) with `vfs.getFileSize`, `vfs.fileExists`, and `vfs.removeFile`. + +### 3.5 WAL Handling for `:memory:` +**Decision**: WAL will NOT be bypassed for `:memory:` databases in v1. + +**Rationale**: +1. **Simplicity**: No changes required to WAL, Pager, or recovery code paths. +2. **Correctness**: Preserves Snapshot Isolation semantics exactly as they work for disk-based databases. +3. **Test coverage**: Existing WAL tests will automatically cover in-memory databases. + +The `MemVfs` simulates the filesystem safely, so the WAL will function correctly in memory. The overhead of writing to a memory-backed WAL and checkpointing to a memory-backed DB is minimal (just `memcpy` operations). + +**Future Optimization**: A `DurabilityMode.dmNone` fast-path could be added later to bypass WAL entirely for `:memory:` databases, reducing memory overhead by ~50% (no double-buffering of data in both DB and WAL). This would require: +- ADR documenting the trade-offs +- Changes to `Wal.nim` to skip WAL writes when `dmNone` is set +- Changes to `Pager.nim` to write directly to the DB file +- Testing to ensure Snapshot Isolation still works correctly + +## 4. Performance and Memory Considerations +- **Memory Usage**: The memory usage is bounded to `Size of DB in MemVfs` + `Size of WAL in MemVfs` + `Pager Cache Size (default 4MB)`. The WAL file size is bounded by the auto-checkpoint interval (default 64MB). This means the overhead of double-buffering is minimal and perfectly acceptable for an in-memory database. 
+- **CPU Overhead**: Reading and writing to `MemVfs` involves `memcpy`, which is extremely fast (microseconds per page). +- **Concurrency**: Because the WAL is still used, Snapshot Isolation and concurrent readers work exactly as they do for disk-based databases. + +## 4.1 Risks and Mitigations +| Risk | Impact | Mitigation | +|------|--------|------------| +| Memory exhaustion | Process OOM if in-memory DB grows unbounded | Document that `:memory:` is for ephemeral workloads; consider adding optional memory limits in future | +| GC pressure | Large in-memory DBs cause GC pauses | Use `seq[byte]` instead of `string` for binary data; pre-allocate buffers where possible | +| VfsFile refactoring breaks existing code | Compilation errors in dependent modules | Incremental migration: first add base class, then update OsVfs, then add MemVfs | +| mmap not supported | Performance regression if engine relies on mmap | Engine already handles `supportsMmap() == false` gracefully (used for fault-injection tests) | + +## 4.2 ADR Requirement +Per `design/adr/README.md`, an ADR is required for decisions that "have meaningful trade-offs that future contributors will need to understand." + +**ADR Required**: Yes. Create `design/adr/NNNN-inmemory-vfs-design.md` documenting: +1. The decision to use inheritance-based VFS extensibility vs. composition +2. The decision to keep WAL enabled for `:memory:` in v1 +3. The trade-offs of memory overhead vs. correctness guarantees +4. The mmap support decision (not supported for MemVfs) + +## 5. Testing Strategy +- Add unit tests for `MemVfs` in `tests/vfs/test_mem_vfs.nim`. +- Add a test in `tests/test_engine.nim` that opens `:memory:`, creates tables, inserts data, and verifies concurrent readers work. 
+- Add a test in `tests/test_engine.nim` that verifies multiple independent `:memory:` databases do not share state. +- Add a test in `tests/test_engine.nim` that verifies transaction rollback works correctly in `:memory:` (insert data, rollback, verify data is gone). +- Add a test in `tests/test_engine.nim` that verifies memory is released when `Db` is closed (open `:memory:`, insert large data, close, open again, verify memory usage dropped — e.g. measure RSS drops via `getrusage` or equivalent). +- Run core benchmark smoke tests against `:memory:`. +- Ensure that closing the `Db` object properly frees the `MemVfs` memory (no memory leaks). + +## 5.1 Connection String Handling +The following connection string patterns will be supported: +- `:memory:` - Standard in-memory database (each connection gets a new isolated instance) + +**Not supported in v1** (deferred to future): +- `file::memory:?cache=shared` - Shared in-memory database across connections (requires global registry) +- Named in-memory databases (e.g., `file:mydb?mode=memory`) - Requires additional state management + +The `:memory:` pattern matches SQLite's behavior for simplicity and familiarity. + +## 6. Documentation Updates +- Update `docs/getting-started/README.md` and `docs/api/README.md` to explain the `:memory:` connection string. +- Clearly state that in-memory databases are process-bound, destroyed when closed, ideal for tests/ephemeral workloads, and durability guarantees apply only to the process lifetime. +- Document that `:memory:` databases do not support mmap (no performance impact expected). +- Add example code showing how to use `:memory:` for unit testing. + +## 7. 
Implementation Checklist +- [ ] Create ADR for in-memory VFS design decisions +- [ ] Refactor `VfsFile` to be a base class +- [ ] Create `OsVfsFile` inheriting from `VfsFile` +- [ ] Update `OsVfs` to use `OsVfsFile` +- [ ] Add `getFileSize`, `fileExists`, `removeFile` to `Vfs` interface +- [ ] Implement `OsVfs` methods for new interface +- [ ] Create `MemVfs` and `MemVfsFile` +- [ ] Update `openDb` to detect `:memory:` +- [ ] Replace direct `os` calls with VFS methods in engine/pager/wal +- [ ] Add unit tests for `MemVfs` +- [ ] Add integration tests for `:memory:` databases +- [ ] Update documentation +- [ ] Run full test suite to verify no regressions diff --git a/design/adr/0096-case-insensitive-identifiers.md b/design/adr/0096-case-insensitive-identifiers.md new file mode 100644 index 0000000..bd4f254 --- /dev/null +++ b/design/adr/0096-case-insensitive-identifiers.md @@ -0,0 +1,52 @@ +# ADR-0096: Case-Insensitive Identifier Resolution + +**Status:** Accepted +**Date:** 2025-07-18 +**Context:** Bug fix โ€” catalog and binder performed case-sensitive comparisons despite using pg_query which follows PostgreSQL identifier folding rules. + +## Problem + +DecentDB uses `libpg_query` (ADR-0035) to parse SQL. PostgreSQL's parser lowercases all unquoted identifiers โ€” `SELECT Name FROM Users` becomes `select name from users`. Quoted identifiers preserve case โ€” `SELECT "Name" FROM "Users"` keeps original casing. + +DecentDB's catalog stored identifiers with their original case (as received from the parser) and compared them **case-sensitively**. This created a mismatch: + +1. `CREATE TABLE "Users" ("Name" TEXT)` โ†’ catalog stores table `Users`, column `Name` +2. `SELECT Name FROM Users` โ†’ parser produces `name`, `users` (lowercased) +3. Catalog lookup for `users` fails because map key is `Users` + +This is a correctness bug, not a feature request. 
Any SQL tool that mixes quoted and unquoted identifiers (EF Core, raw SQL, any PostgreSQL-compatible client) hits this. + +## Decision + +Normalize identifier comparisons to case-insensitive at the **comparison point**, not at the storage point. + +### Why comparison-time, not storage-time? + +- **No persistent format change** โ€” column and table names stored on disk keep their original case +- **Backward compatible** โ€” existing databases load correctly without migration +- **Display fidelity** โ€” `DESCRIBE TABLE` shows the name the user originally chose +- **PostgreSQL semantics** โ€” PostgreSQL also stores the original (folded) name and compares case-insensitively for unquoted identifiers + +### What changed + +| Module | Change | Hot path? | +|--------|--------|-----------| +| `catalog.nim` | Map keys (tables, views, indexes) normalized via `normalizedObjectName()` | Once per statement | +| `binder.nim` | `resolveColumn()` pre-normalizes lookup name, compares via `normalizedName()` | Once per column ref per statement | +| `binder.nim` | DDL/conflict column matching uses `eqIdent()` helper | DDL only (infrequent) | +| `engine.nim` | `columnIndexInTable()` pre-normalizes lookup name | Per column lookup in execution | +| `sql.nim` | CREATE TABLE constraint column matching lowercased | DDL only | + +### Performance impact + +- `toLowerAscii()` on identifier strings (typically <64 chars): sub-microsecond +- Lookup keys are normalized once outside loops, not per iteration +- `eqIdent()` is `{.inline.}` and used only in DDL paths +- No measurable impact on query throughput + +## Consequences + +- Mixed quoting now works: `CREATE TABLE "Users" (...)` + `SELECT * FROM Users` succeeds +- EF Core (which quotes all identifiers) and raw SQL (which typically doesn't quote) interoperate correctly +- All bindings (Python, .NET, Go, Node) benefit from the fix +- No database migration needed โ€” existing files work unchanged diff --git 
a/design/adr/0097-shared-library-embedding-safety.md b/design/adr/0097-shared-library-embedding-safety.md new file mode 100644 index 0000000..e96251a --- /dev/null +++ b/design/adr/0097-shared-library-embedding-safety.md @@ -0,0 +1,48 @@ +## Shared Library Embedding Safety +**Date:** 2026-02-22 +**Status:** Accepted + +### Decision + +When building DecentDB as a shared library (`build_lib` task), apply three compile-time and runtime changes to ensure safe embedding in host runtimes (.NET, JVM, Python, Go, etc.): + +1. **`-d:noSignalHandler`** โ€” Disable Nim's built-in signal handler +2. **`-d:useMalloc`** โ€” Use the system allocator instead of Nim's thread-local allocator +3. **Pager eviction on `closeDb()`** โ€” Evict stale `Pager` references from threadvar caches + +These flags apply only to the shared library build target (`nimble build_lib`), not to the standalone CLI or test binaries. + +### Rationale + +**Signal handler conflict (`noSignalHandler`):** +Nim's runtime installs a SIGSEGV handler that prints a stack trace and aborts. Host runtimes (.NET CLR, JVM HotSpot) also use SIGSEGV-class signals internally for GC write barriers, null-reference traps, and stack probing. Two competing signal handlers on the same signal cause non-deterministic crashes. Disabling Nim's handler lets the host runtime manage signals as it expects. + +**Thread-local allocator (`useMalloc`):** +Nim's default allocator (`nimAllocPagesViaMmap`) maintains per-thread free lists. Host runtimes like .NET's async/await model routinely allocate objects on one OS thread and free them on another (task continuation on a different thread-pool thread). This is safe with system `malloc`/`free` (which are thread-safe) but causes heap corruption with Nim's thread-local allocator. Using `-d:useMalloc` delegates all allocation to the system allocator. 
+ +**Pager eviction (`closeDb()`):** +DecentDB uses three threadvar caches for performance: `gAppendCache` (B-tree append optimization), `gReusableBTree` (avoids per-insert BTree allocation), and `gEvalPager`/`gEvalCatalog` (execution context reuse). These hold raw `Pager` pointers. Under ARC, when a database is closed and its `Pager` deallocated, these threadvar entries become dangling references. On the next `openDb()`, the thread-local cache still points to freed memory. In host runtimes that reuse threads (thread pools), this leads to use-after-free. The fix evicts all entries referencing a specific `Pager` during `closeDb()`. + +### Alternatives Considered + +1. **Apply flags globally (all build targets):** Rejected โ€” standalone CLI and tests benefit from Nim's signal handler (crash diagnostics) and thread-local allocator (faster allocation). The embedding issues only manifest when a host runtime controls the process. + +2. **Weak references for threadvar caches:** Nim lacks native weak references under ARC. Simulating them (ref counting wrapper + nil check) adds complexity for minimal gain when explicit eviction is straightforward and correct. + +3. **Remove threadvar caches entirely:** Would eliminate the dangling reference problem but would regress insert performance (the append cache and reusable BTree are measurable optimizations on bulk workloads). + +### Trade-offs + +| Aspect | Impact | +|--------|--------| +| Allocation performance | `-d:useMalloc` is slightly slower than Nim's arena allocator for small, frequent allocations. Unmeasurable in benchmarks for DecentDB's workload (I/O-dominated). | +| Crash diagnostics | No Nim stack trace on SIGSEGV in the shared library. Host runtime's crash handler takes over (typically provides better diagnostics anyway). | +| `closeDb()` cost | Three threadvar scans add ~microseconds to close. Negligible vs. fsync/flush. 
| +| Standalone builds | Unaffected โ€” flags are only in `build_lib`, not `nimble test` or CLI targets. | + +### References + +- Nim manual: [noSignalHandler](https://nim-lang.org/docs/nimc.html) +- Nim manual: [useMalloc](https://nim-lang.org/docs/nimc.html) +- ADR-0011: Memory Management Strategy +- ADR-0025: Memory Leak Prevention Strategy diff --git a/design/adr/0098-left-join-subquery-column-resolution.md b/design/adr/0098-left-join-subquery-column-resolution.md new file mode 100644 index 0000000..ba3868a --- /dev/null +++ b/design/adr/0098-left-join-subquery-column-resolution.md @@ -0,0 +1,46 @@ +## LEFT JOIN Subquery Column Resolution +**Date:** 2026-02-22 +**Status:** Accepted + +### Decision + +Fix two related bugs in `exec.nim` that caused incorrect column resolution when a LEFT JOIN's right side is a subquery that returns zero matching rows: + +1. **`columnIndex` qualified lookup:** When a qualified column reference (`table.column`) fails the direct lookup, do not fall through to unqualified matching across all tables. Instead, attempt unqualified matching only against columns with no table prefix, preventing false "Ambiguous column" errors. + +2. **Subquery NULL padding:** When a LEFT JOIN's right side is a `pkSubqueryScan` that returns zero rows, derive the right-side column names from the inner plan's projection list (or, as a last resort, from the underlying table scan's catalog entry). This allows the LEFT JOIN to correctly pad NULL values for all right-side columns. + +### Rationale + +**The problem:** +SQL standard requires LEFT JOIN to produce NULL values for all right-side columns when no matching row exists. DecentDB's join execution populates `rightColumns` (the list of right-side column names) from the first result row. When zero rows match, `rightColumns` stays empty. 
A catalog-based fallback exists for real-table joins (`plan.right.table` โ†’ catalog lookup), but for subquery-backed joins (`pkSubqueryScan`), `plan.right.table` is often empty โ€” the subquery is anonymous or aliased. With no column names for the right side, the outer SELECT cannot resolve references to right-side columns, producing "Unknown column" or "Ambiguous column" errors. + +This affects any SQL that LEFT JOINs on a subquery where the right side might match zero rows โ€” a common pattern generated by ORMs (Entity Framework Core, SQLAlchemy, Django ORM) and hand-written SQL alike. + +**`columnIndex` fix:** +Previously, when qualified lookup (`u0.Id`) failed, the code fell through to unqualified matching that searched ALL columns for `Id`. In a multi-table JOIN, `Id` exists in multiple tables, producing a false "Ambiguous column" error. The fix restricts the fallback: when a table qualifier was provided and the qualified lookup fails, only match columns with no table prefix (bare `Id`), not `other_table.Id`. + +**Subquery NULL padding fix:** +When `rightColumns` is empty and the right plan is `pkSubqueryScan`, traverse the inner plan tree to find the `pkProject` node (which determines output columns). Extract column names from the projection's aliases or column expressions, prefixed with the subquery's alias. If no projection is found, fall back to the innermost `pkTableScan`'s catalog columns. + +### Alternatives Considered + +1. **Store column metadata in Row type:** Add a "schema" field to `Row` that carries column names even for empty result sets. Rejected โ€” this is a fundamental change to the row model affecting all execution paths, with broader testing and performance implications than a targeted fix. + +2. **Execute the subquery unconditionally and capture column names from the plan before filtering:** Rejected โ€” the subquery may already be filtered at the plan level; re-executing without filters would change semantics. + +3. 
**Require all subquery JOINs to have a named table:** Not feasible โ€” the SQL standard and ORM-generated SQL both allow anonymous subqueries in JOINs. + +### Trade-offs + +| Aspect | Impact | +|--------|--------| +| Correctness | LEFT JOIN on subqueries now returns correct NULL-padded rows when right side matches nothing. Fixes SQL standard compliance gap. | +| Performance | Plan traversal to find projections adds negligible overhead โ€” only runs when `rightColumns` is empty (zero-row case), not on every join. | +| `columnIndex` behavior change | Qualified lookup that previously fell through to cross-table unqualified matching now returns "Unknown column" instead of "Ambiguous column". This is more correct but changes the error message for genuinely missing columns. | +| Plan structure dependency | The fix traverses `subPlan`, `left`, and projection fields of the plan tree. Changes to the planner's plan structure could break this traversal โ€” mitigated by the fallback chain (projection โ†’ table scan โ†’ give up gracefully). | + +### References + +- SQL:2016 ยง7.7 (LEFT OUTER JOIN semantics) +- ADR-0096: Case-Insensitive Identifier Resolution (related column resolution changes) diff --git a/design/adr/0099-sqlite-type-affinity-comparisons.md b/design/adr/0099-sqlite-type-affinity-comparisons.md new file mode 100644 index 0000000..e057473 --- /dev/null +++ b/design/adr/0099-sqlite-type-affinity-comparisons.md @@ -0,0 +1,50 @@ +# ADR-0099: SQLite-Compatible Type Affinity in Comparisons + +**Status:** Accepted +**Date:** 2025-07-22 +**Context:** SQLite compatibility โ€” comparisons between INTEGER and TEXT values must coerce types following SQLite's type affinity rules. + +## Problem + +DecentDB's `compareValues` performs strict type matching: if operands have different `ValueKind`s, it returns `cmp(a.kind, b.kind)` without attempting conversion. This means `42 = '42'` evaluates to `false`. 
+ +SQLite applies **type affinity** rules (ยง3.2 of the SQLite documentation): when comparing an INTEGER/REAL value with a TEXT value, SQLite attempts to convert the TEXT to a numeric value. If conversion succeeds, it compares numerically. If conversion fails, the numeric value is always considered less than the TEXT value. + +This affects all users, not just EF Core. Any SQL migrated from SQLite (or written assuming SQLite-compatible behavior) that compares numeric columns to string literals or parameters will silently return wrong results. + +Common patterns affected: +- `WHERE id = '42'` (string literal vs INTEGER column) +- `WHERE id = @p0` where parameter is bound as TEXT +- EF Core `EF.Property(entity, "Id") == "42"` pattern +- Any ORM that passes all parameters as strings + +## Decision + +Add SQLite-compatible type affinity coercion to `compareValues` in `src/exec/exec.nim`. The rules are: + +1. **INTEGER vs TEXT**: Try to parse the TEXT as an integer. If successful, compare as integers. Otherwise, INTEGER < TEXT. +2. **FLOAT64 vs TEXT**: Try to parse the TEXT as a float. If successful, compare as floats. Otherwise, FLOAT64 < TEXT. +3. **INTEGER vs FLOAT64**: Convert the integer to float and compare. (Already partially handled by existing code paths but made explicit.) +4. **BLOB vs any non-BLOB**: BLOB is always greater than INTEGER, FLOAT, and TEXT (SQLite rule). +5. **NULL**: NULL handling is unchanged (NULL is not equal to anything, handled by existing COALESCE/IS NULL logic). + +The coercion is applied symmetrically (a vs b and b vs a). + +## Performance Impact + +- The fast path (same-kind comparison) is unchanged โ€” the `if a.kind != b.kind` check still exits early when kinds match. +- The coercion path only triggers for cross-type comparisons, which are uncommon in well-typed schemas. +- Text-to-integer parsing uses `parseBiggestInt` which is a single pass over the string โ€” negligible cost. 
+ +## Consequences + +- `42 = '42'` now returns `true` (matches SQLite behavior) +- `42 < 'abc'` now returns `true` (numeric < non-numeric text, matches SQLite) +- `42.0 = '42.0'` now returns `true` +- Ordering of mixed-type columns now follows SQLite's collation: NULL < INTEGER/REAL < TEXT < BLOB +- No persistent format changes โ€” this only affects in-memory evaluation +- No changes to indexing or storage โ€” indexes store typed values as before + +## Risks + +- Applications that relied on strict type comparison (INTEGER โ‰  TEXT) may see different results. This is considered acceptable because the previous behavior was a compatibility bug, not a feature. diff --git a/design/adr/0100-partial-index-query-planner-exclusion.md b/design/adr/0100-partial-index-query-planner-exclusion.md new file mode 100644 index 0000000..212e4d0 --- /dev/null +++ b/design/adr/0100-partial-index-query-planner-exclusion.md @@ -0,0 +1,24 @@ +# ADR 0100: Partial Index Query Planner Exclusion + +## Status + +Accepted + +## Context + +DecentDB supports partial (filtered) indexes created with `CREATE UNIQUE INDEX ... WHERE predicate`. The index only contains rows matching the predicate. Prior to this change, the query planner's `getBtreeIndexForColumn` and `getIndexForColumn` functions returned the first matching B-tree index for a column without checking whether the index was partial. + +This caused incorrect query results: a query like `SELECT ... WHERE Type = 3` would use a partial index with predicate `WHERE Type != 3`. Since the index only contains rows where `Type != 3`, seeking `Type = 3` returned zero results โ€” a correctness bug. + +## Decision + +Skip partial indexes (those with non-empty `predicateSql`) in `getBtreeIndexForColumn` and `getIndexForColumn`. When only a partial index exists for a column, these functions return `none`, causing the query planner to fall back to a table scan. + +This is the simplest correct approach. 
A more sophisticated approach would evaluate whether the seek value satisfies the index predicate, but this adds complexity and the table scan fallback is correct for all cases. + +## Consequences + +- **Correctness**: Queries whose filter doesn't match a partial index predicate now return correct results via table scan. +- **Performance**: Queries that *could* benefit from a partial index (where the seek value matches the predicate) will use a table scan instead. This is a conservative trade-off favoring correctness over optimization. +- **Future work**: A smarter approach could pass the seek value to the index selection function and evaluate it against the predicate, allowing partial index use when the value is known to satisfy the predicate. +- **No format changes**: This is a query planner behavior change only; no persistent format or WAL changes. diff --git a/design/adr/0101-nodatime-member-translation-via-plugin.md b/design/adr/0101-nodatime-member-translation-via-plugin.md new file mode 100644 index 0000000..9596641 --- /dev/null +++ b/design/adr/0101-nodatime-member-translation-via-plugin.md @@ -0,0 +1,36 @@ +# ADR 0101: NodaTime Member Translation via IMemberTranslatorPlugin + +## Status + +Accepted + +## Context + +NodaTime types (`LocalDate`, `Instant`, `LocalDateTime`) are stored as integers in DecentDB (epoch days for `LocalDate`, ticks for `Instant`/`LocalDateTime`). EF Core LINQ queries that access date components (e.g., `x.ReleaseDate.Year`) need SQL translation. + +DecentDB has no built-in date functions (`strftime`, `EXTRACT`, etc.), so date component extraction must use pure integer arithmetic. + +## Decision + +Implement NodaTime member translation using: + +1. **`IMemberTranslatorPlugin` interface** โ€” the EF Core extensibility point designed for adding new member translators without replacing existing ones. Registered as `Scoped` (matching `ISqlExpressionFactory` lifetime). + +2. 
**Hinnant civil calendar algorithm** โ€” pure integer arithmetic to extract Year, Month, Day, and DayOfYear from epoch days. No date functions required. + +3. **Explicit CAST for type safety** โ€” The input column (typed as `LocalDate` in CLR) is cast to `long` before arithmetic to prevent CLR type propagation through `SqlBinaryExpression`, which would cause `GroupBy` translation to fail with "No coercion operator" errors. + +## Alternatives Considered + +- **Replacing `IMemberTranslatorProvider`** โ€” Caused `ArgumentNullException` in EF Core's service provider due to lifetime mismatches and hash code collisions. The plugin approach is the sanctioned EF Core extension pattern. + +- **Adding date functions to DecentDB core** โ€” Rejected: these are only needed by EF Core, not by general DecentDB users. Keeping them in the binding layer follows the principle that EF-specific logic stays in the EF provider. + +- **Client-side evaluation** โ€” Would require materializing entire result sets before grouping/filtering by date components. Unacceptable for performance. + +## Consequences + +- `LocalDate.Year`, `.Month`, `.Day`, `.DayOfYear` are translatable to SQL in LINQ queries +- `GroupBy(x => x.Date.Year)` works correctly for statistics and reporting queries +- Generated SQL is verbose (nested arithmetic) but correct and performant +- No changes to DecentDB core required diff --git a/design/adr/0102-json-scalar-functions.md b/design/adr/0102-json-scalar-functions.md new file mode 100644 index 0000000..833611d --- /dev/null +++ b/design/adr/0102-json-scalar-functions.md @@ -0,0 +1,36 @@ +# ADR 0102: JSON Scalar Functions (json_array_length, json_extract) + +## Status + +Accepted + +## Context + +The EF Core provider for DecentDB needs to support `string[]` properties stored as JSON (e.g., `["Rock","Jazz"]`). EF Core translates LINQ operations on these collections into SQL using `json_each()` table-valued functions. 
DecentDB does not support table-valued functions and adding them would be a significant engine change. + +The EF provider intercepts `json_each()` patterns and rewrites them to scalar function equivalents: +- `.Any()` / `.Length > 0` โ†’ `json_array_length(column) > 0` +- `.Count()` โ†’ `json_array_length(column)` +- `array[index]` โ†’ `json_extract(column, '$[N]')` + +These scalar functions must exist in the DecentDB SQL engine for the rewritten queries to execute. + +## Decision + +Add `json_array_length(text [, path])` and `json_extract(text, path)` as built-in scalar functions in `exec.nim`. These functions: + +- Parse JSON using Nim's `std/json` module +- Support JSONPath-style `$` root and `$[N]` array index notation +- Return NULL for NULL inputs (SQL NULL propagation) +- Return NULL for invalid JSON (no runtime errors for malformed data) +- `json_array_length` returns an integer count of array elements (or object keys) +- `json_extract` returns the extracted value as TEXT, INT64, FLOAT64, BOOL, or NULL depending on the JSON node type + +These are general-purpose SQL functions useful beyond EF Core โ€” any SQL user can call them. + +## Consequences + +- Two new scalar functions available to all DecentDB users via SQL +- Depends on `std/json` (Nim stdlib โ€” no new external dependency) +- JSON parsing occurs at query time per row; no pre-built JSON indexes +- Enables the EF Core provider to handle primitive collection patterns without core table-valued function support diff --git a/design/adr/0103-efcore-primitive-collection-translation.md b/design/adr/0103-efcore-primitive-collection-translation.md new file mode 100644 index 0000000..4b4f81e --- /dev/null +++ b/design/adr/0103-efcore-primitive-collection-translation.md @@ -0,0 +1,31 @@ +# ADR 0103: EF Core Primitive Collection Query Translation + +## Status + +Accepted + +## Context + +EF Core 8+ supports primitive collections โ€” properties like `string[]` stored as JSON arrays in a single column. 
The default translation pipeline creates `json_each()` table-valued function (TVF) calls to flatten arrays into rows for LINQ operations (`.Any()`, `.Count()`, `.Contains()`, `array[index]`). + +DecentDB does not support table-valued functions. Implementing them would require significant planner and executor changes affecting all users. + +## Decision + +Intercept and rewrite all `json_each()` patterns in the EF Core provider's `QueryableMethodTranslatingExpressionVisitor` before SQL generation. Five overrides handle the complete pattern set: + +1. **`TranslatePrimitiveCollection`**: Creates a `JsonEachExpression` intermediate representation (never reaches SQL generation) +2. **`TranslateAny`**: Detects `JsonEachExpression` in SELECT โ†’ rewrites to `json_array_length(col) > 0` +3. **`TranslateCount`**: Detects `JsonEachExpression` in SELECT โ†’ rewrites to `json_array_length(col)` +4. **`TranslateElementAtOrDefault`**: Detects `JsonEachExpression` with constant index โ†’ rewrites to `json_extract(col, '$[N]')` +5. **`TranslateContains`**: Detects `JsonEachExpression` source โ†’ rewrites to `col LIKE '%"' || @value || '"%'` + +The `JsonEachExpression` is a custom `TableValuedFunctionExpression` that serves as a marker โ€” it is always optimized away by one of the four translation overrides before reaching the SQL generator. 
+ +## Consequences + +- All primitive collection LINQ patterns work without core TVF support +- No changes to DecentDB engine required (uses existing `json_array_length`, `json_extract`, and `LIKE`) +- `Contains` uses LIKE pattern matching (`'%"value"%'`) which is correct for simple string values but would false-match on values containing `"` characters +- The LIKE-based Contains cannot use indexes (requires full column scan) +- Future: if DecentDB adds TVF support, the provider could remove these overrides and use the default EF Core translation diff --git a/design/adr/0104-sql-planner-aggregate-and-composite-pk-fixes.md b/design/adr/0104-sql-planner-aggregate-and-composite-pk-fixes.md new file mode 100644 index 0000000..2e3d6cf --- /dev/null +++ b/design/adr/0104-sql-planner-aggregate-and-composite-pk-fixes.md @@ -0,0 +1,30 @@ +# ADR 0104: SQL Planner โ€” Aggregate ORDER BY, Composite PK, and UNION Subquery Fixes + +## Status + +Accepted + +## Context + +Three related query planner and binder issues surfaced during EF Core integration testing: + +1. **ORDER BY aggregate alias**: `SELECT COUNT(*) AS "AlbumCount" ... GROUP BY ... ORDER BY "AlbumCount"` failed because the binder replaced the alias reference with a clone of the aggregate expression. The sort node then tried to re-evaluate `COUNT(*)` against individual rows instead of the materialized aggregate output. + +2. **Composite primary key rowid seek**: Tables with composite PKs (e.g., `(AlbumId, ArtistId)`) were incorrectly treated as having a rowid alias on the first integer PK column. The planner generated a rowid seek instead of an index seek, returning wrong results. + +3. **UNION subquery derived tables**: `FROM (SELECT ... UNION SELECT ...)` failed in `subqueryTableMeta` because the function tried to derive columns from the UNION node's (empty) select items instead of recursing into the left operand. + +## Decision + +1. 
**Binder**: When an ORDER BY alias matches a SELECT item whose expression is an aggregate function (`COUNT`, `SUM`, `AVG`, `MIN`, `MAX`, `GROUP_CONCAT`, `STRING_AGG`), keep the ORDER BY as a column reference to the alias instead of cloning the aggregate expression. **Planner**: When building a Sort node above an Aggregate node, rewrite any ORDER BY items that contain aggregate expressions to column references matching the SELECT alias. + +2. **Planner**: `isRowidPkColumn` now counts PK columns first. If the table has more than one PK column (composite key), no column is treated as a rowid alias. + +3. **Binder**: `subqueryTableMeta` recursively follows `setOpLeft` for UNION/INTERSECT/EXCEPT nodes to find the actual SELECT items that define the result columns. Also updated `selectToCanonicalSql` to serialize `fromSubquery` and `joinSubqueries` instead of emitting empty `FROM` clauses. + +## Consequences + +- EF Core queries with `ORDER BY` on aggregated columns work correctly +- Composite PK tables use index seeks instead of incorrect rowid seeks +- Subqueries using set operations (UNION, INTERSECT, EXCEPT) work as derived tables +- No performance impact โ€” these are correctness fixes in cold (planning) paths diff --git a/docs/about/changelog.md b/docs/about/changelog.md index 782d2f9..880aa81 100644 --- a/docs/about/changelog.md +++ b/docs/about/changelog.md @@ -5,6 +5,32 @@ All notable changes to DecentDB will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.4.0] - 2026-02-22 + +### Added +- **SQL Engine**: `json_array_length(json [, path])` scalar function โ€” returns the number of elements in a JSON array. Supports optional JSONPath for nested access. See ADR-0102. 
+- **SQL Engine**: `json_extract(json, path)` scalar function โ€” extracts a value from a JSON document using JSONPath notation (`$`, `$[N]`, `$.key`). Returns the appropriate SQL type (TEXT, INT64, FLOAT64, BOOL, NULL). See ADR-0102. +- .NET: EF Core primitive collection support โ€” `string[]` properties stored as JSON are now fully queryable via LINQ (`.Any()`, `.Count()`, `.Contains()`, `array[index]`, `.Select()`). See ADR-0103. +- .NET: NodaTime member translation plugin โ€” `Instant.InUtc().Year/Month/Day` and other date part extractions now translate to SQL expressions. See ADR-0101. +- .NET: `SqlStatementSplitter` for batch SQL execution in ADO.NET layer. +- .NET: `SqlParameterRewriter` tests covering named, positional, and mixed parameter styles. +- .NET: Primitive collection tests (11 tests covering Any, Count, Contains, ElementAt, Select, null/empty arrays). + +### Fixed +- **SQL Engine**: Case-insensitive identifier resolution following PostgreSQL semantics โ€” unquoted identifiers (lowercased by the parser) now correctly match tables, columns, and indexes created with quoted identifiers and vice versa. See ADR-0096. +- **SQL Engine**: SQLite-compatible type affinity in comparisons โ€” TEXT values are coerced to INTEGER/FLOAT for comparison operators and rowid seeks, matching SQLite behavior. See ADR-0099. +- **SQL Engine**: Partial index query planner exclusion โ€” partial indexes are no longer selected by the general query planner, preventing incorrect results when the query predicate doesn't match the index predicate. See ADR-0100. +- **SQL Engine**: Composite primary key tables no longer incorrectly use rowid seeks โ€” individual columns in composite PKs are not rowid aliases. See ADR-0104. +- **SQL Engine**: `ORDER BY` on aggregate aliases (e.g., `ORDER BY "AlbumCount"`) now correctly references the materialized aggregate output instead of re-evaluating the aggregate function. See ADR-0104. +- **SQL Engine**: `FROM (SELECT ... 
UNION SELECT ...)` subqueries now correctly derive column metadata from the left operand. See ADR-0104. +- **SQL Engine**: LEFT JOIN column resolution with subquery returning zero rows โ€” correctly derive column names from inner projection and pad NULLs. See ADR-0098. +- **SQL Engine**: `IN (SELECT ...)` parameter scanning now correctly finds `$N` references inside subquery SQL text. +- **Shared Library**: Disable Nim's signal handler (`-d:noSignalHandler`) and use system allocator (`-d:useMalloc`) to prevent conflicts with host runtimes (.NET, JVM). See ADR-0097. +- **Engine**: Evict stale Pager references from threadvar caches in `closeDb()` to prevent memory leaks under ARC. See ADR-0097. +- .NET: NodaTime `Instant` type mapping now uses tick-level precision (100ns ticks since Unix epoch) instead of millisecond precision, matching .NET `DateTimeOffset.Ticks` behavior. +- .NET: EF Core database creator now handles table/index creation failures gracefully during `EnsureCreated()`. +- .NET: ADO.NET `SqlParameterRewriter` correctly handles parameters in complex subqueries and multi-statement SQL. + ## [1.3.0] - 2026-02-21 ### Added diff --git a/docs/user-guide/comparison.md b/docs/user-guide/comparison.md index 2cd17c7..28ede89 100644 --- a/docs/user-guide/comparison.md +++ b/docs/user-guide/comparison.md @@ -49,7 +49,7 @@ DecentDB's current baseline includes: - CTEs: non-recursive `WITH ... 
AS` - Window functions: `ROW_NUMBER() OVER (...)` - Predicates: comparisons (`=`, `!=`, `<>`, `<`, `<=`, `>`, `>=`), `AND`/`OR`/`NOT`, `LIKE`/`ILIKE`, `IN`, `BETWEEN`, `EXISTS`, `IS NULL`/`IS NOT NULL` -- Scalar functions: `COALESCE`, `NULLIF`, `CAST`, `CASE`, `LENGTH`, `LOWER`, `UPPER`, `TRIM`, `REPLACE`, `SUBSTRING`/`SUBSTR`, `ABS`, `ROUND`, `CEIL`/`CEILING`, `FLOOR`, `GEN_RANDOM_UUID`, `UUID_PARSE`, `UUID_TO_STRING` +- Scalar functions: `COALESCE`, `NULLIF`, `CAST`, `CASE`, `LENGTH`, `LOWER`, `UPPER`, `TRIM`, `REPLACE`, `SUBSTRING`/`SUBSTR`, `ABS`, `ROUND`, `CEIL`/`CEILING`, `FLOOR`, `GEN_RANDOM_UUID`, `UUID_PARSE`, `UUID_TO_STRING`, `JSON_ARRAY_LENGTH`, `JSON_EXTRACT`, `PRINTF` - Operators: `+`, `-`, `*`, `/`, `||` (string concatenation) - Parameters: positional `$1, $2, ...` (Postgres-style) - `EXPLAIN` / `EXPLAIN ANALYZE` plan output diff --git a/docs/user-guide/sql-reference.md b/docs/user-guide/sql-reference.md index 1ecc260..138c7a9 100644 --- a/docs/user-guide/sql-reference.md +++ b/docs/user-guide/sql-reference.md @@ -321,6 +321,13 @@ Supported scalar functions: - `UUID_PARSE` - `UUID_TO_STRING` +**JSON:** +- `JSON_ARRAY_LENGTH(json [, path])` โ€” returns element count of a JSON array +- `JSON_EXTRACT(json, path)` โ€” extracts a value using JSONPath (`$`, `$[N]`, `$.key`) + +**Other:** +- `PRINTF(format, args...)` โ€” formatted string output (SQLite-compatible) + ```sql SELECT COALESCE(nickname, name) FROM users; SELECT NULLIF(status, 'active') FROM users; @@ -335,6 +342,10 @@ SELECT TRIM(name) || '_suffix' FROM users; SELECT CAST(id AS TEXT) FROM users; SELECT CAST('12.34' AS DECIMAL(10,2)); SELECT CASE WHEN active THEN 'on' ELSE 'off' END FROM users; +SELECT JSON_ARRAY_LENGTH('["a","b","c"]'); -- Returns 3 +SELECT JSON_EXTRACT('{"name":"Alice"}', '$.name'); -- Returns 'Alice' +SELECT JSON_EXTRACT('["x","y","z"]', '$[1]'); -- Returns 'y' +SELECT PRINTF('Hello %s, you are %d', name, age) FROM users; ``` ### Common Table Expressions (CTE) @@ -390,6 +401,8 
@@ SELECT AVG(price) FROM products; SELECT MIN(created_at), MAX(created_at) FROM users; SELECT category, SUM(amount) FROM orders GROUP BY category; SELECT category, COUNT(*) FROM orders GROUP BY category HAVING COUNT(*) > 5; +SELECT GROUP_CONCAT(name, ', ') FROM users; -- Concatenate with separator +SELECT STRING_AGG(name, ', ') FROM users; -- Alias for GROUP_CONCAT ``` ### Window Functions diff --git a/examples/repro_capi_reopen.nim b/examples/repro_capi_reopen.nim new file mode 100644 index 0000000..dc6e554 --- /dev/null +++ b/examples/repro_capi_reopen.nim @@ -0,0 +1,55 @@ +import ../src/c_api +import os, strutils + +let dbPath = "/tmp/test_capi_reopen.ddb" +removeFile(dbPath) + +proc exec(p: pointer, sql: string) = + var stmt: pointer + let rc = decentdb_prepare(p, sql.cstring, addr stmt) + assert rc == 0, "prepare failed for: " & sql & " err=" & $decentdb_last_error_message(p) + discard decentdb_step(stmt) + decentdb_finalize(stmt) + +# Phase 1: Create DB with many tables through C API +block: + let p = decentdb_open(dbPath.cstring, nil) + assert p != nil, "Failed to open: " & $decentdb_last_error_message(nil) + + exec(p, """CREATE TABLE "Libraries" ("Id" INTEGER PRIMARY KEY, "Name" TEXT NOT NULL, "AlbumCount" INTEGER, "ApiKey" TEXT NOT NULL, "ArtistCount" INTEGER, "CreatedAt" TEXT, "Description" TEXT, "IsLocked" INTEGER, "LastScanAt" TEXT, "LastUpdatedAt" TEXT, "Notes" TEXT, "Path" TEXT, "SongCount" INTEGER, "SortOrder" INTEGER, "Tags" TEXT, "Type" INTEGER)""") + + exec(p, """CREATE TABLE "Artists" ("Id" INTEGER PRIMARY KEY, "Name" TEXT NOT NULL, "NameNormalized" TEXT, "LibraryId" INTEGER NOT NULL REFERENCES "Libraries"("Id"), "AlbumCount" INTEGER, "AlternateNames" TEXT, "AmgId" TEXT, "ApiKey" TEXT NOT NULL, "Biography" TEXT, "CalculatedRating" REAL, "CreatedAt" TEXT NOT NULL, "Directory" TEXT, "Tags" TEXT)""") + + # Create 50 more tables to match Melodee scale + for i in 1..50: + exec(p, "CREATE TABLE \"Table" & $i & "\" (\"Id\" INTEGER PRIMARY KEY, \"Name\" 
TEXT, \"Value\" TEXT)") + + exec(p, """INSERT INTO "Libraries" ("Id", "Name", "ApiKey", "CreatedAt") VALUES (1, 'TestLib', 'test-key', '2025-01-01')""") + + for i in 1..100: + exec(p, "INSERT INTO \"Table1\" (\"Id\", \"Name\", \"Value\") VALUES (" & $i & ", 'Key" & $i & "', 'Value" & $i & "')") + + discard decentdb_close(p) + echo "Phase 1: Created 52 tables + seed data, closed" + +# Phase 2: Reopen and query with INNER JOIN +block: + let p = decentdb_open(dbPath.cstring, nil) + assert p != nil, "Failed to reopen: " & $decentdb_last_error_message(nil) + + echo "About to run INNER JOIN SELECT..." + var stmt: pointer + let rc = decentdb_prepare(p, """SELECT "a"."Id", "a"."AlbumCount", "a"."AlternateNames", "a"."AmgId", "a"."ApiKey", "a"."Biography", "a"."CalculatedRating", "a"."CreatedAt", "a"."Directory", "a"."LibraryId", "a"."Name", "a"."NameNormalized", "a"."Tags", "l"."Id", "l"."AlbumCount", "l"."ApiKey", "l"."ArtistCount", "l"."CreatedAt", "l"."Description", "l"."IsLocked", "l"."LastScanAt", "l"."LastUpdatedAt", "l"."Name", "l"."Notes", "l"."Path", "l"."SongCount", "l"."SortOrder", "l"."Tags", "l"."Type" FROM "Artists" AS "a" INNER JOIN "Libraries" AS "l" ON "a"."LibraryId" = "l"."Id" WHERE "a"."Id" = 999999 LIMIT 1""".cstring, addr stmt) + + if rc != 0: + echo "Prepare failed: ", $decentdb_last_error_message(p) + else: + let stepRc = decentdb_step(stmt) + echo "Step result: ", stepRc + decentdb_finalize(stmt) + echo "Query completed successfully" + + discard decentdb_close(p) + echo "Done" + +removeFile(dbPath) diff --git a/examples/repro_join_reopen.nim b/examples/repro_join_reopen.nim new file mode 100644 index 0000000..7f3b661 --- /dev/null +++ b/examples/repro_join_reopen.nim @@ -0,0 +1,89 @@ +import ../src/engine + +import os + +let dbPath = "/tmp/test_join_crash.ddb" +removeFile(dbPath) + +# Create and populate +block: + let res = openDb(dbPath) + assert res.ok, "Failed to open: " & res.err.message + let db = res.value + + discard execSql(db, """ + CREATE 
TABLE "Artists" ( + "Id" INTEGER PRIMARY KEY, + "Name" TEXT NOT NULL, + "NameNormalized" TEXT, + "LibraryId" INTEGER NOT NULL, + "AlbumCount" INTEGER, + "AlternateNames" TEXT, + "AmgId" TEXT, + "ApiKey" TEXT NOT NULL, + "Biography" TEXT, + "CreatedAt" TEXT NOT NULL, + "Directory" TEXT, + "Tags" TEXT + ) + """) + + discard execSql(db, """ + CREATE TABLE "Libraries" ( + "Id" INTEGER PRIMARY KEY, + "Name" TEXT NOT NULL, + "AlbumCount" INTEGER, + "ApiKey" TEXT NOT NULL, + "ArtistCount" INTEGER, + "CreatedAt" TEXT, + "Description" TEXT, + "IsLocked" INTEGER, + "LastScanAt" TEXT, + "LastUpdatedAt" TEXT, + "Notes" TEXT, + "Path" TEXT, + "SongCount" INTEGER, + "SortOrder" INTEGER, + "Tags" TEXT, + "Type" INTEGER + ) + """) + + discard execSql(db, """ + INSERT INTO "Libraries" ("Id", "Name", "ApiKey", "CreatedAt") + VALUES (1, 'TestLib', 'test-key', '2025-01-01') + """) + + discard closeDb(db) + echo "DB created and closed" + +# Reopen and query +block: + let res = openDb(dbPath) + assert res.ok, "Failed to reopen: " & res.err.message + let db = res.value + + echo "About to run INNER JOIN SELECT..." 
+ let r = execSql(db, """ + SELECT "a"."Id", "a"."AlbumCount", "a"."AlternateNames", "a"."AmgId", "a"."ApiKey", + "a"."Biography", "a"."CreatedAt", "a"."Directory", "a"."LibraryId", + "a"."Name", "a"."NameNormalized", "a"."Tags", + "l"."Id", "l"."AlbumCount", "l"."ApiKey", "l"."ArtistCount", + "l"."CreatedAt", "l"."Description", "l"."IsLocked", "l"."LastScanAt", + "l"."LastUpdatedAt", "l"."Name", "l"."Notes", "l"."Path", + "l"."SongCount", "l"."SortOrder", "l"."Tags", "l"."Type" + FROM "Artists" AS "a" + INNER JOIN "Libraries" AS "l" ON "a"."LibraryId" = "l"."Id" + WHERE "a"."Id" = 999999 + LIMIT 1 + """) + + if r.ok: + echo "Query OK, rows: ", r.value.len + else: + echo "Query failed: ", r.err.message + + discard closeDb(db) + echo "Done" + +removeFile(dbPath) diff --git a/examples/repro_many_tables_reopen.nim b/examples/repro_many_tables_reopen.nim new file mode 100644 index 0000000..cf82c25 --- /dev/null +++ b/examples/repro_many_tables_reopen.nim @@ -0,0 +1,133 @@ +import ../src/engine +import os, strutils + +let dbPath = "/tmp/test_many_tables.ddb" +removeFile(dbPath) + +# Create DB with 52 tables (mimicking Melodee schema) +block: + let res = openDb(dbPath) + assert res.ok, "Failed to open: " & res.err.message + let db = res.value + + # Create a Library table + discard execSql(db, """ + CREATE TABLE "Libraries" ( + "Id" INTEGER PRIMARY KEY, + "Name" TEXT NOT NULL, + "AlbumCount" INTEGER, + "ApiKey" TEXT NOT NULL, + "ArtistCount" INTEGER, + "CreatedAt" TEXT, + "Description" TEXT, + "IsLocked" INTEGER, + "LastScanAt" TEXT, + "LastUpdatedAt" TEXT, + "Notes" TEXT, + "Path" TEXT, + "SongCount" INTEGER, + "SortOrder" INTEGER, + "Tags" TEXT, + "Type" INTEGER + ) + """) + + # Create Artists table with FK to Libraries + discard execSql(db, """ + CREATE TABLE "Artists" ( + "Id" INTEGER PRIMARY KEY, + "Name" TEXT NOT NULL, + "NameNormalized" TEXT, + "LibraryId" INTEGER NOT NULL, + "AlbumCount" INTEGER, + "AlternateNames" TEXT, + "AmgId" TEXT, + "ApiKey" TEXT NOT NULL, + 
"Biography" TEXT, + "CalculatedRating" REAL, + "CreatedAt" TEXT NOT NULL, + "DeezerId" TEXT, + "Description" TEXT, + "Directory" TEXT, + "DiscogsId" TEXT, + "ImageCount" INTEGER, + "IsLocked" INTEGER, + "ItunesId" TEXT, + "LastFmId" TEXT, + "LastMetaDataUpdatedAt" TEXT, + "LastPlayedAt" TEXT, + "LastUpdatedAt" TEXT, + "MetaDataStatus" INTEGER, + "MusicBrainzId" TEXT, + "Notes" TEXT, + "PlayedCount" INTEGER, + "RealName" TEXT, + "Roles" TEXT, + "SongCount" INTEGER, + "SortName" TEXT, + "SortOrder" INTEGER, + "SpotifyId" TEXT, + "Tags" TEXT, + "WikiDataId" TEXT, + FOREIGN KEY ("LibraryId") REFERENCES "Libraries" ("Id") + ) + """) + + # Create 50 more dummy tables to match the Melodee schema count + for i in 1..50: + let name = "Table" & $i + let sql = "CREATE TABLE \"" & name & "\" (\"Id\" INTEGER PRIMARY KEY, \"Name\" TEXT, \"Value\" TEXT, \"Ref\" INTEGER)" + let r = execSql(db, sql) + assert r.ok, "Create " & name & " failed: " & r.err.message + + # Insert seed data - Library + discard execSql(db, """ + INSERT INTO "Libraries" ("Id", "Name", "ApiKey", "CreatedAt") + VALUES (1, 'TestLib', 'test-key', '2025-01-01') + """) + + # Insert 312 rows of seed data into a dummy table (mimicking Settings) + for i in 1..312: + let sql = "INSERT INTO \"Table1\" (\"Id\", \"Name\", \"Value\") VALUES (" & $i & ", 'Key" & $i & "', 'Value" & $i & "')" + let r = execSql(db, sql) + assert r.ok, "Insert row " & $i & " failed: " & r.err.message + + discard closeDb(db) + echo "DB created with 52 tables and 312 seed rows, closed" + +# Reopen and run INNER JOIN SELECT +block: + let res = openDb(dbPath) + assert res.ok, "Failed to reopen: " & res.err.message + let db = res.value + + echo "About to run INNER JOIN SELECT..." 
+ let r = execSql(db, """ + SELECT "a"."Id", "a"."AlbumCount", "a"."AlternateNames", "a"."AmgId", "a"."ApiKey", + "a"."Biography", "a"."CalculatedRating", "a"."CreatedAt", "a"."DeezerId", + "a"."Description", "a"."Directory", "a"."DiscogsId", "a"."ImageCount", + "a"."IsLocked", "a"."ItunesId", "a"."LastFmId", "a"."LastMetaDataUpdatedAt", + "a"."LastPlayedAt", "a"."LastUpdatedAt", "a"."LibraryId", + "a"."MetaDataStatus", "a"."MusicBrainzId", "a"."Name", "a"."NameNormalized", + "a"."Notes", "a"."PlayedCount", "a"."RealName", "a"."Roles", + "a"."SongCount", "a"."SortName", "a"."SortOrder", "a"."SpotifyId", + "a"."Tags", "a"."WikiDataId", + "l"."Id", "l"."AlbumCount", "l"."ApiKey", "l"."ArtistCount", + "l"."CreatedAt", "l"."Description", "l"."IsLocked", "l"."LastScanAt", + "l"."LastUpdatedAt", "l"."Name", "l"."Notes", "l"."Path", + "l"."SongCount", "l"."SortOrder", "l"."Tags", "l"."Type" + FROM "Artists" AS "a" + INNER JOIN "Libraries" AS "l" ON "a"."LibraryId" = "l"."Id" + WHERE "a"."Id" = 999999 + LIMIT 1 + """) + + if r.ok: + echo "Query OK, rows: ", r.value.len + else: + echo "Query failed: ", r.err.message + + discard closeDb(db) + echo "Done" + +removeFile(dbPath) diff --git a/src/btree/btree.nim b/src/btree/btree.nim index a794fa3..7be57d2 100644 --- a/src/btree/btree.nim +++ b/src/btree/btree.nim @@ -48,6 +48,19 @@ var gAppendCache {.threadvar.}: Table[PageId, AppendCacheEntry] var gLastCacheRoot {.threadvar.}: PageId var gLastCachePtr {.threadvar.}: ptr AppendCacheEntry +proc evictPagerFromAppendCache*(pager: Pager) = + ## Remove all append-cache entries belonging to a specific pager. + ## Call this when closing a database to prevent leaked Pager refs. 
+ if gAppendCache.len == 0: + return + var toDelete: seq[PageId] + for root, entry in gAppendCache: + if entry.pager == pager: + toDelete.add(root) + for root in toDelete: + gAppendCache.del(root) + gLastCachePtr = nil + proc encodeVarintToBuf*(v: uint64, buf: var array[10, byte]): int {.inline.} = ## Encode a varint into a stack-allocated buffer, returning the number of bytes written. var x = v diff --git a/src/c_api.nim b/src/c_api.nim index 86d91a7..78dfd93 100644 --- a/src/c_api.nim +++ b/src/c_api.nim @@ -329,6 +329,10 @@ proc findMaxParam(stmt: Statement): int = if e.args.len == 1 and e.args[0].kind == ekLiteral and e.args[0].value.kind == svString: scanSqlForParams(e.args[0].value.strVal) + elif e.funcName == "IN_SUBQUERY": + if e.args.len == 2 and e.args[1].kind == ekLiteral and + e.args[1].value.kind == svString: + scanSqlForParams(e.args[1].value.strVal) for a in e.args: walk(a) of ekInList: walk(e.inExpr); (for a in e.inList: walk(a)) else: discard diff --git a/src/catalog/catalog.nim b/src/catalog/catalog.nim index c985dc6..90bb8bd 100644 --- a/src/catalog/catalog.nim +++ b/src/catalog/catalog.nim @@ -444,7 +444,7 @@ proc rebuildReverseFkCache(catalog: Catalog) = for col in meta.columns: if col.refTable.len == 0 or col.refColumn.len == 0: continue - let key = (col.refTable, col.refColumn) + let key = (normalizedObjectName(col.refTable), normalizedObjectName(col.refColumn)) if not catalog.reverseFkRefs.hasKey(key): catalog.reverseFkRefs[key] = @[] catalog.reverseFkRefs[key].add(ReferencingChildFk( @@ -523,11 +523,11 @@ proc initCatalog*(pager: Pager): Result[Catalog] = if maxKeyRes.ok: if table.nextRowId <= maxKeyRes.value: table.nextRowId = maxKeyRes.value + 1 - catalog.tables[record.table.name] = table + catalog.tables[normalizedObjectName(record.table.name)] = table of crIndex: - catalog.indexes[record.index.name] = record.index + catalog.indexes[normalizedObjectName(record.index.name)] = record.index of crView: - catalog.views[record.view.name] = 
record.view + catalog.views[normalizedObjectName(record.view.name)] = record.view of crTrigger: catalog.triggers[triggerMetaKey(record.trigger.table, record.trigger.name)] = record.trigger rebuildDependentViewsIndex(catalog) @@ -567,20 +567,20 @@ proc allTrigramDeltas*(catalog: Catalog): seq[((string, uint32), TrigramDelta)] proc updateTableMeta*(catalog: Catalog, table: TableMeta) = ## Updates the in-memory metadata for a table without persisting to disk. ## Use with caution: changes will be lost on crash if not followed by saveTable eventually. - catalog.tables[table.name] = table + catalog.tables[normalizedObjectName(table.name)] = table proc updateTableMetaFast*(catalog: Catalog, tableName: string, nextRowId: uint64, rootPage: PageId) {.inline.} = ## Updates only nextRowId and rootPage in the in-memory table metadata. ## Avoids copying the entire TableMeta struct when only these fields change. - catalog.tables.withValue(tableName, entry): + catalog.tables.withValue(normalizedObjectName(tableName), entry): entry.nextRowId = nextRowId entry.rootPage = rootPage proc saveTable*(catalog: Catalog, pager: Pager, table: TableMeta): Result[Void] = var rebuildFk = true - if catalog.tables.hasKey(table.name): - rebuildFk = catalog.tables[table.name].columns != table.columns - catalog.tables[table.name] = table + if catalog.tables.hasKey(normalizedObjectName(table.name)): + rebuildFk = catalog.tables[normalizedObjectName(table.name)].columns != table.columns + catalog.tables[normalizedObjectName(table.name)] = table if rebuildFk: rebuildReverseFkCache(catalog) let key = uint64(crc32c(stringToBytes("table:" & table.name))) @@ -604,19 +604,19 @@ proc saveTable*(catalog: Catalog, pager: Pager, table: TableMeta): Result[Void] okVoid() proc getTable*(catalog: Catalog, name: string): Result[TableMeta] = - if not catalog.tables.hasKey(name): + if not catalog.tables.hasKey(normalizedObjectName(name)): return err[TableMeta](ERR_SQL, "Table not found", name) - 
ok(catalog.tables[name]) + ok(catalog.tables[normalizedObjectName(name)]) proc getTablePtr*(catalog: Catalog, name: string): ptr TableMeta = ## Returns a mutable pointer into the catalog table map. Caller must not ## hold this across operations that could rehash catalog.tables. - catalog.tables.withValue(name, v): + catalog.tables.withValue(normalizedObjectName(name), v): return addr v[] return nil proc createIndexMeta*(catalog: Catalog, index: IndexMeta): Result[Void] = - catalog.indexes[index.name] = index + catalog.indexes[normalizedObjectName(index.name)] = index let key = uint64(crc32c(stringToBytes("index:" & index.name))) let record = makeIndexRecord(index.name, index.table, index.columns, index.rootPage, index.kind, index.unique, index.predicateSql) let insertRes = insert(catalog.catalogTree, key, record) @@ -629,7 +629,7 @@ proc createIndexMeta*(catalog: Catalog, index: IndexMeta): Result[Void] = okVoid() proc saveIndexMeta*(catalog: Catalog, index: IndexMeta): Result[Void] = - catalog.indexes[index.name] = index + catalog.indexes[normalizedObjectName(index.name)] = index let key = uint64(crc32c(stringToBytes("index:" & index.name))) discard delete(catalog.catalogTree, key) let record = makeIndexRecord(index.name, index.table, index.columns, index.rootPage, index.kind, index.unique, index.predicateSql) @@ -643,15 +643,16 @@ proc saveIndexMeta*(catalog: Catalog, index: IndexMeta): Result[Void] = okVoid() proc createViewMeta*(catalog: Catalog, view: ViewMeta): Result[Void] = - if catalog.views.hasKey(view.name): + let normName = normalizedObjectName(view.name) + if catalog.views.hasKey(normName): return err[Void](ERR_SQL, "View already exists", view.name) - catalog.views[view.name] = view + catalog.views[normName] = view rebuildDependentViewsIndex(catalog) let key = uint64(crc32c(stringToBytes("view:" & view.name))) let record = makeViewRecord(view.name, view.sqlText, view.columnNames, view.dependencies) let insertRes = insert(catalog.catalogTree, key, 
record) if not insertRes.ok: - catalog.views.del(view.name) + catalog.views.del(normName) rebuildDependentViewsIndex(catalog) return err[Void](insertRes.err.code, insertRes.err.message, insertRes.err.context) if catalog.catalogTree.root != catalog.catalogTree.pager.header.rootCatalog: @@ -659,7 +660,7 @@ proc createViewMeta*(catalog: Catalog, view: ViewMeta): Result[Void] = okVoid() proc saveViewMeta*(catalog: Catalog, view: ViewMeta): Result[Void] = - catalog.views[view.name] = view + catalog.views[normalizedObjectName(view.name)] = view rebuildDependentViewsIndex(catalog) let key = uint64(crc32c(stringToBytes("view:" & view.name))) let record = makeViewRecord(view.name, view.sqlText, view.columnNames, view.dependencies) @@ -678,27 +679,30 @@ proc saveViewMeta*(catalog: Catalog, view: ViewMeta): Result[Void] = okVoid() proc getView*(catalog: Catalog, name: string): Result[ViewMeta] = - if not catalog.views.hasKey(name): + let normName = normalizedObjectName(name) + if not catalog.views.hasKey(normName): return err[ViewMeta](ERR_SQL, "View not found", name) - ok(catalog.views[name]) + ok(catalog.views[normName]) proc createTriggerMeta*(catalog: Catalog, trigger: TriggerMeta): Result[Void] proc dropTrigger*(catalog: Catalog, tableName: string, triggerName: string): Result[Void] proc dropView*(catalog: Catalog, name: string): Result[Void] = - if not catalog.views.hasKey(name): + let normName = normalizedObjectName(name) + if not catalog.views.hasKey(normName): return err[Void](ERR_SQL, "View not found", name) + let originalName = catalog.views[normName].name var triggerNames: seq[string] = @[] for _, trigger in catalog.triggers: - if normalizedObjectName(trigger.table) == normalizedObjectName(name): + if normalizedObjectName(trigger.table) == normName: triggerNames.add(trigger.name) for triggerName in triggerNames: let dropTrigRes = dropTrigger(catalog, name, triggerName) if not dropTrigRes.ok: return dropTrigRes - catalog.views.del(name) + 
catalog.views.del(normName) rebuildDependentViewsIndex(catalog) - let key = uint64(crc32c(stringToBytes("view:" & name))) + let key = uint64(crc32c(stringToBytes("view:" & originalName))) let delRes = delete(catalog.catalogTree, key) if not delRes.ok: return err[Void](delRes.err.code, delRes.err.message, delRes.err.context) @@ -707,18 +711,20 @@ proc dropView*(catalog: Catalog, name: string): Result[Void] = okVoid() proc renameView*(catalog: Catalog, oldName: string, newName: string): Result[Void] = - if not catalog.views.hasKey(oldName): + let normOldName = normalizedObjectName(oldName) + let normNewName = normalizedObjectName(newName) + if not catalog.views.hasKey(normOldName): return err[Void](ERR_SQL, "View not found", oldName) - if catalog.views.hasKey(newName): + if catalog.views.hasKey(normNewName): return err[Void](ERR_SQL, "View already exists", newName) - var view = catalog.views[oldName] - let oldKey = uint64(crc32c(stringToBytes("view:" & oldName))) + var view = catalog.views[normOldName] + let oldKey = uint64(crc32c(stringToBytes("view:" & view.name))) let delRes = delete(catalog.catalogTree, oldKey) if not delRes.ok: return err[Void](delRes.err.code, delRes.err.message, delRes.err.context) - catalog.views.del(oldName) + catalog.views.del(normOldName) view.name = newName - catalog.views[newName] = view + catalog.views[normNewName] = view var triggerMetas: seq[TriggerMeta] = @[] for _, trigger in catalog.triggers: if normalizedObjectName(trigger.table) == normalizedObjectName(oldName): @@ -774,7 +780,7 @@ proc dropTrigger*(catalog: Catalog, tableName: string, triggerName: string): Res okVoid() proc referencingChildren*(catalog: Catalog, tableName: string, columnName: string): seq[ReferencingChildFk] = - let key = (tableName, columnName) + let key = (normalizedObjectName(tableName), normalizedObjectName(columnName)) if catalog.reverseFkRefs.hasKey(key): return catalog.reverseFkRefs[key] @[] @@ -814,19 +820,21 @@ proc listDependentViews*(catalog: Catalog, 
objectName: string): seq[string] = result.sort() proc dropTable*(catalog: Catalog, name: string): Result[Void] = - if not catalog.tables.hasKey(name): + let normName = normalizedObjectName(name) + if not catalog.tables.hasKey(normName): return err[Void](ERR_SQL, "Table not found", name) + let originalName = catalog.tables[normName].name var triggerNames: seq[string] = @[] for _, trigger in catalog.triggers: - if normalizedObjectName(trigger.table) == normalizedObjectName(name): + if normalizedObjectName(trigger.table) == normName: triggerNames.add(trigger.name) for triggerName in triggerNames: let dropTrigRes = dropTrigger(catalog, name, triggerName) if not dropTrigRes.ok: return dropTrigRes - catalog.tables.del(name) + catalog.tables.del(normName) rebuildReverseFkCache(catalog) - let key = uint64(crc32c(stringToBytes("table:" & name))) + let key = uint64(crc32c(stringToBytes("table:" & originalName))) let delRes = delete(catalog.catalogTree, key) if not delRes.ok: return err[Void](delRes.err.code, delRes.err.message, delRes.err.context) @@ -837,10 +845,12 @@ proc dropTable*(catalog: Catalog, name: string): Result[Void] = okVoid() proc dropIndex*(catalog: Catalog, name: string): Result[Void] = - if not catalog.indexes.hasKey(name): + let normName = normalizedObjectName(name) + if not catalog.indexes.hasKey(normName): return err[Void](ERR_SQL, "Index not found", name) - catalog.indexes.del(name) - let key = uint64(crc32c(stringToBytes("index:" & name))) + let originalName = catalog.indexes[normName].name + catalog.indexes.del(normName) + let key = uint64(crc32c(stringToBytes("index:" & originalName))) let delRes = delete(catalog.catalogTree, key) if not delRes.ok: return err[Void](delRes.err.code, delRes.err.message, delRes.err.context) @@ -851,38 +861,46 @@ proc dropIndex*(catalog: Catalog, name: string): Result[Void] = okVoid() proc getBtreeIndexForColumn*(catalog: Catalog, table: string, column: string): Option[IndexMeta] = + let normTable = 
normalizedObjectName(table) for _, idx in catalog.indexes: - if idx.table == table and idx.columns.len == 1 and idx.columns[0] == column and idx.kind == ikBtree: + if normalizedObjectName(idx.table) == normTable and idx.columns.len == 1 and normalizedObjectName(idx.columns[0]) == normalizedObjectName(column) and idx.kind == ikBtree and idx.predicateSql.len == 0: return some(idx) none(IndexMeta) proc getIndexForColumn*(catalog: Catalog, table: string, column: string, kind: IndexKind, requireUnique: bool = false): Option[IndexMeta] = ## Returns any single-column index that semantically satisfies the requested signature. ## If requireUnique is true, only unique indexes satisfy. + ## Partial indexes (with predicateSql) are excluded โ€” they may not cover all rows. + let normTable = normalizedObjectName(table) + let normColumn = normalizedObjectName(column) for _, idx in catalog.indexes: - if idx.table != table or idx.columns.len != 1 or idx.columns[0] != column or idx.kind != kind: + if normalizedObjectName(idx.table) != normTable or idx.columns.len != 1 or normalizedObjectName(idx.columns[0]) != normColumn or idx.kind != kind: continue if requireUnique and not idx.unique: continue + if idx.predicateSql.len > 0: + continue return some(idx) none(IndexMeta) proc getTrigramIndexForColumn*(catalog: Catalog, table: string, column: string): Option[IndexMeta] = + let normTable = normalizedObjectName(table) for _, idx in catalog.indexes: - if idx.table == table and idx.columns.len == 1 and idx.columns[0] == column and idx.kind == ikTrigram: + if normalizedObjectName(idx.table) == normTable and idx.columns.len == 1 and normalizedObjectName(idx.columns[0]) == normalizedObjectName(column) and idx.kind == ikTrigram: return some(idx) none(IndexMeta) proc getIndexByName*(catalog: Catalog, name: string): Option[IndexMeta] = - if catalog.indexes.hasKey(name): - return some(catalog.indexes[name]) + let normName = normalizedObjectName(name) + if catalog.indexes.hasKey(normName): + 
return some(catalog.indexes[normName]) none(IndexMeta) proc hasTableName*(catalog: Catalog, name: string): bool = - catalog.tables.hasKey(name) + catalog.tables.hasKey(normalizedObjectName(name)) proc hasViewName*(catalog: Catalog, name: string): bool = - catalog.views.hasKey(name) + catalog.views.hasKey(normalizedObjectName(name)) proc hasTableOrViewName*(catalog: Catalog, name: string): bool = - catalog.tables.hasKey(name) or catalog.views.hasKey(name) + catalog.tables.hasKey(normalizedObjectName(name)) or catalog.views.hasKey(normalizedObjectName(name)) diff --git a/src/engine.nim b/src/engine.nim index 13d9b9c..c6238c4 100644 --- a/src/engine.nim +++ b/src/engine.nim @@ -534,13 +534,17 @@ proc enforceUnique(catalog: Catalog, pager: Pager, table: TableMeta, values: seq return err[Void](ERR_CONSTRAINT, "UNIQUE constraint failed", table.name & "." & col.name) # Check unique indexes (composite and single-column CREATE UNIQUE INDEX) for _, idx in catalog.indexes: - if idx.table != table.name or not idx.unique: + if idx.table.toLowerAscii() != table.name.toLowerAscii() or not idx.unique: continue + # Skip rows excluded by partial index predicate + if idx.predicateSql.len > 0: + if not shouldIncludeInIndex(table, idx, values): + continue # Skip single-column indexes already covered by inline col.unique above if idx.columns.len == 1: var coveredInline = false for col in table.columns: - if col.name == idx.columns[0] and col.unique: + if col.name.toLowerAscii() == idx.columns[0].toLowerAscii() and col.unique: coveredInline = true break if coveredInline: @@ -678,7 +682,7 @@ proc findConflictRowidOnTarget( for colName in targetCols: var idx = -1 for i, col in table.columns: - if col.name == colName: + if col.name.toLowerAscii() == colName.toLowerAscii(): idx = i break if idx < 0: @@ -720,9 +724,15 @@ proc findConflictRowidOnTarget( var matchedIndex: Option[IndexMeta] = none(IndexMeta) for _, idx in catalog.indexes: - if idx.table == table.name and idx.unique and 
idx.columns == targetCols: - matchedIndex = some(idx) - break + if idx.table.toLowerAscii() == table.name.toLowerAscii() and idx.unique and idx.columns.len == targetCols.len: + var colsMatch = true + for i in 0..= 0: @@ -3070,7 +3097,7 @@ proc tryFastPkUpdate(db: Db, bound: Statement, params: seq[Value]): Result[Optio for colName, expr in bound.assignments: var idx = -1 for i, col in table.columns: - if col.name == colName: + if col.name.toLowerAscii() == colName.toLowerAscii(): idx = i break if idx >= 0: @@ -3214,7 +3241,7 @@ proc execPreparedNonSelect*(db: Db, bound: Statement, params: seq[Value], plan: return err[int64](ERR_SQL, "Cannot drop table with dependent views", bound.dropTableName) var toDrop: seq[string] = @[] for name, idx in db.catalog.indexes: - if idx.table == bound.dropTableName: + if idx.table.toLowerAscii() == bound.dropTableName.toLowerAscii(): toDrop.add(name) for idxName in toDrop: discard db.catalog.dropIndex(idxName) @@ -3246,7 +3273,7 @@ proc execPreparedNonSelect*(db: Db, bound: Statement, params: seq[Value], plan: for colName in bound.columnNames: var found = false for i, col in table.columns: - if col.name == colName: + if col.name.toLowerAscii() == colName.toLowerAscii(): colIndices.add(i) found = true break @@ -3255,11 +3282,20 @@ proc execPreparedNonSelect*(db: Db, bound: Statement, params: seq[Value], plan: let rowsRes = scanTable(db.pager, table) if not rowsRes.ok: return err[int64](rowsRes.err.code, rowsRes.err.message, rowsRes.err.context) + var qualCols2: seq[string] = @[] + if bound.indexPredicate != nil: + for col in table.columns: + qualCols2.add(bound.indexTableName & "." 
& col.name) if colIndices.len == 1: let colKind = table.columns[colIndices[0]].kind var seen: Table[uint64, bool] = initTable[uint64, bool]() var seenVals: Table[uint64, seq[string]] = initTable[uint64, seq[string]]() for row in rowsRes.value: + if bound.indexPredicate != nil: + let evalRow = Row(rowid: row.rowid, columns: qualCols2, values: row.values) + let predRes = evalExpr(evalRow, bound.indexPredicate, @[]) + if predRes.ok and not valueToBool(predRes.value): + continue if row.values[colIndices[0]].kind == vkNull: continue let key = indexKeyFromValue(row.values[colIndices[0]]) @@ -3278,6 +3314,11 @@ proc execPreparedNonSelect*(db: Db, bound: Statement, params: seq[Value], plan: else: var seen: Table[uint64, seq[seq[Value]]] = initTable[uint64, seq[seq[Value]]]() for row in rowsRes.value: + if bound.indexPredicate != nil: + let evalRow = Row(rowid: row.rowid, columns: qualCols2, values: row.values) + let predRes = evalExpr(evalRow, bound.indexPredicate, @[]) + if predRes.ok and not valueToBool(predRes.value): + continue var hasNull = false var vals: seq[Value] = @[] for ci in colIndices: @@ -3450,15 +3491,15 @@ proc execPreparedNonSelect*(db: Db, bound: Statement, params: seq[Value], plan: bound.insertColumns var colIndex = initTable[string, int]() for ci, col in table.columns: - colIndex[col.name] = ci + colIndex[col.name.toLowerAscii()] = ci affected = 0 for row in selRows.value: var vals = newSeq[Value](table.columns.len) for vi in 0 ..< vals.len: vals[vi] = Value(kind: vkNull) for ci, colName in targetCols: - if ci < row.values.len and colIndex.hasKey(colName): - let idx = colIndex[colName] + if ci < row.values.len and colIndex.hasKey(colName.toLowerAscii()): + let idx = colIndex[colName.toLowerAscii()] let checked = typeCheckValue(table.columns[idx], row.values[ci]) if not checked.ok: return err[int64](checked.err.code, checked.err.message, checked.err.context) @@ -3513,7 +3554,7 @@ proc execPreparedNonSelect*(db: Db, bound: Statement, params: seq[Value], 
plan: for colName, expr in bound.assignments: var idx = -1 for i, col in table.columns: - if col.name == colName: + if col.name.toLowerAscii() == colName.toLowerAscii(): idx = i break if idx >= 0: @@ -4058,7 +4099,7 @@ proc bulkLoad*(db: Db, tableName: string, rows: seq[seq[Value]], options: BulkLo if options.disableIndexes: for _, idx in db.catalog.indexes: - if idx.table == tableName: + if idx.table.toLowerAscii() == tableName.toLowerAscii(): let rebuildRes = rebuildIndex(db.pager, db.catalog, idx) if not rebuildRes.ok: return rebuildRes @@ -4083,6 +4124,12 @@ proc closeDb*(db: Db): Result[Void] = if not db.isOpen: return okVoid() + # Evict threadvar cache entries that reference this database's pager, + # preventing leaked Pager refs from accumulating across open/close cycles. + evictPagerFromAppendCache(db.pager) + evictPagerFromReusableBTree(db.pager) + evictPagerFromEvalCache(db.pager) + # Flush any pending trigram deltas before closing (ensure clean shutdown durability) # We ignore errors here to prioritize closing, but log them in a real system if db.catalog.trigramDeltas.len > 0: diff --git a/src/exec/exec.nim b/src/exec/exec.nim index 6c2e140..c0fae75 100644 --- a/src/exec/exec.nim +++ b/src/exec/exec.nim @@ -10,6 +10,7 @@ import std/monotimes import std/times import std/math import std/sysrand +import std/json except `%` import ../errors import ../sql/sql import ../sql/binder @@ -38,6 +39,13 @@ var gEvalContextDepth* {.threadvar.}: int var gEvalPager* {.threadvar.}: Pager var gEvalCatalog* {.threadvar.}: Catalog +proc evictPagerFromEvalCache*(pager: Pager) = + ## Release eval-context caches if they reference the given pager. 
+ if gEvalPager == pager: + gEvalPager = nil + if gEvalCatalog != nil and gEvalCatalog.catalogTree != nil and gEvalCatalog.catalogTree.pager == pager: + gEvalCatalog = nil + proc decimalToString(val: int64, scale: uint8): string = var s = $val if scale == 0: return s @@ -745,6 +753,23 @@ proc execPlan*(pager: Pager, catalog: Catalog, plan: Plan, params: seq[Value]): proc evalExpr*(row: Row, expr: Expr, params: seq[Value]): Result[Value] proc valueToBool*(value: Value): bool +# Value kind classification helpers (used by rowid seek coercion and compareValues) +proc isIntKind(k: ValueKind): bool {.inline.} = + k in {vkInt64, vkInt0, vkInt1} + +proc isTextKind(k: ValueKind): bool {.inline.} = + k in {vkText, vkTextCompressed, vkTextOverflow, vkTextCompressedOverflow} + +proc isBlobKind(k: ValueKind): bool {.inline.} = + k in {vkBlob, vkBlobCompressed, vkBlobOverflow, vkBlobCompressedOverflow} + +proc getIntVal(v: Value): int64 {.inline.} = + case v.kind + of vkInt64: v.int64Val + of vkInt0: 0'i64 + of vkInt1: 1'i64 + else: 0'i64 + proc openRowCursor*(pager: Pager, catalog: Catalog, plan: Plan, params: seq[Value]): Result[RowCursor] = ## Build a forward-only cursor for a query plan. 
## @@ -852,7 +877,15 @@ proc openRowCursor*(pager: Pager, catalog: Catalog, plan: Plan, params: seq[Valu let valueRes = evalExpr(Row(), plan.valueExpr, params) if not valueRes.ok: return err[RowCursor](valueRes.err.code, valueRes.err.message, valueRes.err.context) - if valueRes.value.kind != vkInt64: + var seekVal = valueRes.value + # SQLite type affinity: coerce text to int for rowid seeks + if isTextKind(seekVal.kind): + let s = valueToString(seekVal) + try: + seekVal = Value(kind: vkInt64, int64Val: int64(parseBiggestInt(s))) + except ValueError: + return err[RowCursor](ERR_SQL, "Rowid seek expects INT64") + if seekVal.kind != vkInt64 and seekVal.kind != vkInt0 and seekVal.kind != vkInt1: return err[RowCursor](ERR_SQL, "Rowid seek expects INT64") let tableRes = catalog.getTable(plan.table) if not tableRes.ok: @@ -862,7 +895,7 @@ proc openRowCursor*(pager: Pager, catalog: Catalog, plan: Plan, params: seq[Valu var cols: seq[string] = @[] for col in table.columns: cols.add(prefix & "." & col.name) - let targetRowId = cast[uint64](valueRes.value.int64Val) + let targetRowId = cast[uint64](getIntVal(seekVal)) var done = false let c = RowCursor( columns: cols, @@ -1230,7 +1263,7 @@ proc tryCountNoRowsFast*(pager: Pager, catalog: Catalog, plan: Plan, params: seq var colIndex = -1 for i, col in table.columns: - if col.name == plan.column: + if col.name.toLowerAscii() == plan.column.toLowerAscii(): colIndex = i break if colIndex < 0: @@ -1322,7 +1355,7 @@ proc tryCountNoRowsFast*(pager: Pager, catalog: Catalog, plan: Plan, params: seq var colIndex = -1 for i, col in table.columns: - if col.name == colName: + if col.name.toLowerAscii() == colName.toLowerAscii(): colIndex = i break if colIndex < 0: @@ -1377,22 +1410,34 @@ proc estimateRowBytes*(row: Row): int = result += varintLen(uint64(payloadLen)) + payloadLen proc columnIndex*(row: Row, table: string, name: string): Result[int] = + let normName = name.toLowerAscii() if table.len > 0: - let key = table & "." 
& name + let key = (table & "." & name).toLowerAscii() for i, col in row.columns: - if col == key: + if col.toLowerAscii() == key: return ok(i) - # Qualified lookup failed (e.g. post-projection columns lack alias prefix); - # fall through to unqualified matching below. + # Qualified lookup failed โ€” columns may lack this alias prefix. Try + # matching only columns that are unqualified or share the same table prefix + # to avoid false ambiguity across different tables. + var matches: seq[int] = @[] + for i, col in row.columns: + let normCol = col.toLowerAscii() + if normCol == normName: + matches.add(i) + if matches.len == 1: + return ok(matches[0]) + if matches.len == 0: + return err[int](ERR_SQL, "Unknown column", table & "." & name) + return err[int](ERR_SQL, "Ambiguous column", name) var matches: seq[int] = @[] for i, col in row.columns: - if col == name or col.endsWith("." & name): + let normCol = col.toLowerAscii() + if normCol == normName or normCol.endsWith("." & normName): matches.add(i) if matches.len == 1: return ok(matches[0]) if matches.len == 0: - let ctx = if table.len > 0: table & "." 
& name else: name - return err[int](ERR_SQL, "Unknown column", ctx) + return err[int](ERR_SQL, "Unknown column", name) err[int](ERR_SQL, "Ambiguous column", name) proc evalLiteral(value: SqlValue): Value = @@ -1584,25 +1629,97 @@ proc truthOr(left: TruthValue, right: TruthValue): TruthValue = return tvUnknown tvFalse +proc compareBytes(a, b: Value): int {.inline.} = + let lenA = a.bytes.len + let lenB = b.bytes.len + let minLen = min(lenA, lenB) + if minLen > 0: + let c = cmpMem(unsafeAddr a.bytes[0], unsafeAddr b.bytes[0], minLen) + if c != 0: return c + cmp(lenA, lenB) + proc compareValues*(a: Value, b: Value): int = - if a.kind != b.kind: - return cmp(a.kind, b.kind) - case a.kind - of vkNull: 0 - of vkBool: cmp(a.boolVal, b.boolVal) - of vkInt64: cmp(a.int64Val, b.int64Val) - of vkFloat64: cmp(a.float64Val, b.float64Val) - of vkDecimal: compareDecimals(a, b) - of vkText, vkBlob, vkTextCompressed, vkBlobCompressed: - var lenA = a.bytes.len - var lenB = b.bytes.len - var minLen = min(lenA, lenB) - if minLen > 0: - let c = cmpMem(unsafeAddr a.bytes[0], unsafeAddr b.bytes[0], minLen) - if c != 0: return c - return cmp(lenA, lenB) - else: - 0 + if a.kind == b.kind: + case a.kind + of vkNull: return 0 + of vkBool: return cmp(a.boolVal, b.boolVal) + of vkBoolTrue, vkBoolFalse: return 0 + of vkInt64: return cmp(a.int64Val, b.int64Val) + of vkInt0, vkInt1: return 0 + of vkFloat64: return cmp(a.float64Val, b.float64Val) + of vkDecimal: return compareDecimals(a, b) + of vkText, vkBlob, vkTextCompressed, vkBlobCompressed: + return compareBytes(a, b) + else: return 0 + + # SQLite-compatible type affinity: NULL < INT/REAL < TEXT < BLOB + if a.kind == vkNull: return -1 + if b.kind == vkNull: return 1 + + let aIsInt = isIntKind(a.kind) + let bIsInt = isIntKind(b.kind) + let aIsFloat = a.kind == vkFloat64 + let bIsFloat = b.kind == vkFloat64 + let aIsText = isTextKind(a.kind) + let bIsText = isTextKind(b.kind) + let aIsBlob = isBlobKind(a.kind) + let bIsBlob = 
isBlobKind(b.kind) + + # INT vs INT (different compact kinds, e.g. vkInt0 vs vkInt64) + if aIsInt and bIsInt: + return cmp(getIntVal(a), getIntVal(b)) + + # INT vs FLOAT or FLOAT vs INT โ€” compare as float + if aIsInt and bIsFloat: + return cmp(float64(getIntVal(a)), b.float64Val) + if aIsFloat and bIsInt: + return cmp(a.float64Val, float64(getIntVal(b))) + + # INT/FLOAT vs TEXT โ€” try to coerce TEXT to numeric (SQLite affinity) + if (aIsInt or aIsFloat) and bIsText: + let s = valueToString(b) + try: + let parsed = parseBiggestInt(s) + if aIsInt: return cmp(getIntVal(a), int64(parsed)) + else: return cmp(a.float64Val, float64(parsed)) + except ValueError: + try: + let parsed = parseFloat(s) + if aIsInt: return cmp(float64(getIntVal(a)), parsed) + else: return cmp(a.float64Val, parsed) + except ValueError: + return -1 # Non-numeric text: numeric < text + + if aIsText and (bIsInt or bIsFloat): + let s = valueToString(a) + try: + let parsed = parseBiggestInt(s) + if bIsInt: return cmp(int64(parsed), getIntVal(b)) + else: return cmp(float64(parsed), b.float64Val) + except ValueError: + try: + let parsed = parseFloat(s) + if bIsInt: return cmp(parsed, float64(getIntVal(b))) + else: return cmp(parsed, b.float64Val) + except ValueError: + return 1 # Non-numeric text: text > numeric + + # TEXT vs TEXT (different sub-kinds like compressed vs uncompressed) + if aIsText and bIsText: + return compareBytes(a, b) + + # BLOB vs BLOB (different sub-kinds) + if aIsBlob and bIsBlob: + return compareBytes(a, b) + + # BLOB > everything else; numeric < text (SQLite ordering) + if aIsBlob: return 1 + if bIsBlob: return -1 + if aIsText: return 1 + if bIsText: return -1 + + # Fallback for bool vs int etc. 
+ cmp(ord(a.kind), ord(b.kind)) proc hash*(v: Value): Hash = var h: Hash = 0 @@ -2166,6 +2283,109 @@ proc evalExpr*(row: Row, expr: Expr, params: seq[Value]): Result[Value] = return ok(textValue(valueToString(strRes.value).replace( valueToString(fromRes.value), valueToString(toRes.value)))) + if name == "JSON_ARRAY_LENGTH": + if expr.args.len < 1 or expr.args.len > 2: + return err[Value](ERR_SQL, "JSON_ARRAY_LENGTH requires 1 or 2 arguments") + let argRes = evalExpr(row, expr.args[0], params) + if not argRes.ok: + return err[Value](argRes.err.code, argRes.err.message, argRes.err.context) + if argRes.value.kind == vkNull: + return ok(Value(kind: vkNull)) + let jsonText = valueToString(argRes.value) + try: + var node = parseJson(jsonText) + if expr.args.len == 2: + let pathRes = evalExpr(row, expr.args[1], params) + if not pathRes.ok: + return err[Value](pathRes.err.code, pathRes.err.message, pathRes.err.context) + if pathRes.value.kind == vkNull: + return ok(Value(kind: vkNull)) + let path = valueToString(pathRes.value) + if path == "$": + discard + elif path.startsWith("$."): + for key in path[2..^1].split('.'): + if node.kind == JObject and node.hasKey(key): + node = node[key] + else: + return ok(Value(kind: vkNull)) + else: + return ok(Value(kind: vkNull)) + if node.kind == JArray: + return ok(Value(kind: vkInt64, int64Val: int64(node.len))) + return ok(Value(kind: vkInt64, int64Val: 0'i64)) + except JsonParsingError: + return ok(Value(kind: vkNull)) + + if name == "JSON_EXTRACT": + if expr.args.len != 2: + return err[Value](ERR_SQL, "JSON_EXTRACT requires exactly 2 arguments") + let jsonRes = evalExpr(row, expr.args[0], params) + if not jsonRes.ok: + return err[Value](jsonRes.err.code, jsonRes.err.message, jsonRes.err.context) + if jsonRes.value.kind == vkNull: + return ok(Value(kind: vkNull)) + let pathRes = evalExpr(row, expr.args[1], params) + if not pathRes.ok: + return err[Value](pathRes.err.code, pathRes.err.message, pathRes.err.context) + if 
pathRes.value.kind == vkNull: + return ok(Value(kind: vkNull)) + let jsonText = valueToString(jsonRes.value) + let path = valueToString(pathRes.value) + try: + var node = parseJson(jsonText) + if path == "$": + discard + elif path.startsWith("$"): + var segments: seq[string] = @[] + var i = 1 + while i < path.len: + if path[i] == '.': + inc i + var key = "" + while i < path.len and path[i] != '.' and path[i] != '[': + key.add(path[i]) + inc i + if key.len > 0: + segments.add(key) + elif path[i] == '[': + inc i + var idxStr = "" + while i < path.len and path[i] != ']': + idxStr.add(path[i]) + inc i + if i < path.len: inc i # skip ']' + segments.add("[" & idxStr & "]") + else: + inc i + for seg in segments: + if seg.startsWith("[") and seg.endsWith("]"): + let idxText = seg[1..^2] + try: + let idx = parseInt(idxText) + if node.kind == JArray and idx >= 0 and idx < node.len: + node = node[idx] + else: + return ok(Value(kind: vkNull)) + except ValueError: + return ok(Value(kind: vkNull)) + else: + if node.kind == JObject and node.hasKey(seg): + node = node[seg] + else: + return ok(Value(kind: vkNull)) + else: + return ok(Value(kind: vkNull)) + case node.kind + of JString: return ok(textValue(node.getStr())) + of JInt: return ok(Value(kind: vkInt64, int64Val: node.getBiggestInt())) + of JFloat: return ok(Value(kind: vkFloat64, float64Val: node.getFloat())) + of JBool: return ok(Value(kind: vkInt64, int64Val: if node.getBool(): 1'i64 else: 0'i64)) + of JNull: return ok(Value(kind: vkNull)) + of JObject, JArray: return ok(textValue($node)) + except JsonParsingError: + return ok(Value(kind: vkNull)) + if name == "PRINTF": if expr.args.len < 1: return err[Value](ERR_SQL, "PRINTF requires at least a format argument") @@ -2549,7 +2769,7 @@ proc indexSeekRows(pager: Pager, catalog: Catalog, tableName: string, alias: str cols.add(prefix & "." 
& col.name) var valueIndex = -1 for i, col in table.columns: - if col.name == column: + if col.name.toLowerAscii() == column.toLowerAscii(): valueIndex = i break for rowid in rowIdsRes.value: @@ -2573,7 +2793,7 @@ proc trigramSeekRows(pager: Pager, catalog: Catalog, tableName: string, alias: s let idx = indexOpt.get var columnIndex = -1 for i, col in table.columns: - if col.name == column: + if col.name.toLowerAscii() == column.toLowerAscii(): columnIndex = i break if columnIndex < 0: @@ -3622,7 +3842,14 @@ proc execPlan*(pager: Pager, catalog: Catalog, plan: Plan, params: seq[Value]): let valueRes = evalExpr(Row(), plan.valueExpr, params) if not valueRes.ok: return err[seq[Row]](valueRes.err.code, valueRes.err.message, valueRes.err.context) - if valueRes.value.kind != vkInt64: + var seekVal = valueRes.value + if isTextKind(seekVal.kind): + let s = valueToString(seekVal) + try: + seekVal = Value(kind: vkInt64, int64Val: int64(parseBiggestInt(s))) + except ValueError: + return err[seq[Row]](ERR_SQL, "Rowid seek expects INT64") + if seekVal.kind != vkInt64 and seekVal.kind != vkInt0 and seekVal.kind != vkInt1: return err[seq[Row]](ERR_SQL, "Rowid seek expects INT64") let tableRes = catalog.getTable(plan.table) if not tableRes.ok: @@ -3632,7 +3859,7 @@ proc execPlan*(pager: Pager, catalog: Catalog, plan: Plan, params: seq[Value]): var cols: seq[string] = @[] for col in table.columns: cols.add(prefix & "." & col.name) - let targetRowId = cast[uint64](valueRes.value.int64Val) + let targetRowId = cast[uint64](getIntVal(seekVal)) let readRes = readRowAt(pager, table, targetRowId) if not readRes.ok: if readRes.err.code == ERR_IO and readRes.err.message == "Key not found": @@ -3891,6 +4118,41 @@ proc execPlan*(pager: Pager, catalog: Catalog, plan: Plan, params: seq[Value]): let prefix = if plan.right.alias.len > 0: plan.right.alias else: plan.right.table for col in tableRes.value.columns: rightColumns.add(prefix & "." 
& col.name) + # For subquery-backed joins with zero rows, derive column names from the + # inner plan's projection or underlying table so LEFT JOIN can pad NULLs. + if rightColumns.len == 0 and plan.right.kind == pkSubqueryScan: + let prefix = if plan.right.alias.len > 0: plan.right.alias + elif plan.right.table.len > 0: plan.right.table + else: "" + if prefix.len > 0: + var p = plan.right.subPlan + while p != nil: + if p.kind == pkProject and p.projections.len > 0: + for item in p.projections: + if item.alias.len > 0: + rightColumns.add(prefix & "." & item.alias) + elif item.expr != nil and item.expr.kind == ekColumn: + rightColumns.add(prefix & "." & item.expr.name) + else: + rightColumns.add(prefix & ".?column?") + break + if p.left != nil: p = p.left + elif p.subPlan != nil: p = p.subPlan + else: break + # Last resort: find underlying table scan in the inner plan. + if rightColumns.len == 0: + var p = plan.right.subPlan + while p != nil: + if p.kind == pkTableScan and p.table.len > 0: + let tableRes = catalog.getTable(p.table) + if tableRes.ok: + let pfx = if prefix.len > 0: prefix else: p.table + for col in tableRes.value.columns: + rightColumns.add(pfx & "." & col.name) + break + if p.left != nil: p = p.left + elif p.subPlan != nil: p = p.subPlan + else: break # Pre-compute merged column names once (avoid allocating for each match) var mergedColumns: seq[string] = @[] diff --git a/src/planner/planner.nim b/src/planner/planner.nim index 0cbcec0..d562e80 100644 --- a/src/planner/planner.nim +++ b/src/planner/planner.nim @@ -237,8 +237,17 @@ proc planSelect(catalog: Catalog, stmt: Statement): Plan = proc isRowidPkColumn(colName: string): bool = if not tableRes.ok: return false + # A column is the rowid only when it is the SOLE INT64 primary key. + # Composite PKs use a hash-based unique index; individual columns are + # NOT rowid aliases even when they are INTEGER PRIMARY KEY. 
+ var pkCount = 0 for col in tableMeta.columns: - if col.name == colName and col.primaryKey and col.kind == ctInt64: + if col.primaryKey: + inc pkCount + if pkCount != 1: + return false + for col in tableMeta.columns: + if col.name.toLowerAscii() == colName.toLowerAscii() and col.primaryKey and col.kind == ctInt64: return true false @@ -493,7 +502,29 @@ proc planSelect(catalog: Catalog, stmt: Statement): Plan = if stmt.groupBy.len > 0 or hasAggregate(stmt.selectItems): base = Plan(kind: pkAggregate, groupBy: stmt.groupBy, having: stmt.havingExpr, projections: stmt.selectItems, left: base) if stmt.orderBy.len > 0: - base = Plan(kind: pkSort, orderBy: stmt.orderBy, left: base) + # Rewrite ORDER BY aggregate expressions to column references so the + # sort evaluates against materialized aggregate output columns instead + # of trying to re-evaluate the aggregate function (which would fail). + var rewrittenOrder: seq[OrderItem] = @[] + for item in stmt.orderBy: + if exprHasAggregate(item.expr): + var colName = "" + for si in stmt.selectItems: + if exprToCanonicalSql(si.expr) == exprToCanonicalSql(item.expr): + if si.alias.len > 0: + colName = si.alias + elif si.expr != nil and si.expr.kind == ekFunc: + colName = si.expr.funcName.toLowerAscii() + break + if colName.len > 0: + rewrittenOrder.add(OrderItem( + expr: Expr(kind: ekColumn, name: colName, table: ""), + asc: item.asc)) + else: + rewrittenOrder.add(item) + else: + rewrittenOrder.add(item) + base = Plan(kind: pkSort, orderBy: rewrittenOrder, left: base) else: # Place Sort before Project so that ORDER BY sees the original column # names (not aliases) and expensive projections (e.g. 
scalar subqueries) diff --git a/src/sql/binder.nim b/src/sql/binder.nim index af843c5..df140f7 100644 --- a/src/sql/binder.nim +++ b/src/sql/binder.nim @@ -16,6 +16,16 @@ proc bindStatement*(catalog: Catalog, stmt: Statement): Result[Statement] proc normalizedName(name: string): string = name.toLowerAscii() +proc eqIdent(a, b: string): bool {.inline.} = + ## Case-insensitive identifier comparison following PostgreSQL semantics. + a.toLowerAscii() == b.toLowerAscii() + +proc eqIdentSeq(a, b: seq[string]): bool = + if a.len != b.len: return false + for i in 0.. 0: - map[fromAlias] = baseRes.value + map[normalizedName(fromAlias)] = baseRes.value for join in joins: let tableRes = catalog.getTable(join.table) if not tableRes.ok: return err[Table[string, TableMeta]](tableRes.err.code, tableRes.err.message, tableRes.err.context) - map[join.table] = tableRes.value + map[normalizedName(join.table)] = tableRes.value if join.alias.len > 0: - map[join.alias] = tableRes.value + map[normalizedName(join.alias)] = tableRes.value ok(map) proc resolveColumn(map: Table[string, TableMeta], table: string, name: string): Result[ColumnType] = + let normName = normalizedName(name) if table.len > 0: - if not map.hasKey(table): + let normTable = normalizedName(table) + if not map.hasKey(normTable): return err[ColumnType](ERR_SQL, "Unknown table", table) - let meta = map[table] + let meta = map[normTable] for col in meta.columns: - if col.name == name: + if normalizedName(col.name) == normName: return ok(col.kind) return err[ColumnType](ERR_SQL, "Unknown column", table & "." 
& name) var found: seq[ColumnType] = @[] for _, meta in map: for col in meta.columns: - if col.name == name: + if normalizedName(col.name) == normName: found.add(col.kind) if found.len == 0: return err[ColumnType](ERR_SQL, "Unknown column", name) @@ -469,6 +481,9 @@ proc inferExprType(catalog: Catalog, map: Table[string, TableMeta], expr: Expr): proc subqueryTableMeta(catalog: Catalog, subquery: Statement, alias: string): Result[TableMeta] = ## Derive a synthetic TableMeta from a subquery's SELECT items. + # For UNION/INTERSECT/EXCEPT, columns are defined by the left operand. + if subquery.setOpKind != sokNone and subquery.setOpLeft != nil: + return subqueryTableMeta(catalog, subquery.setOpLeft, alias) var columns: seq[Column] = @[] # Build a table map for the subquery's own sources to resolve column types. var innerMap = initTable[string, TableMeta]() @@ -1484,6 +1499,16 @@ proc bindSelect(catalog: Catalog, stmt: Statement): Result[Statement] = var found = false for si in expanded.selectItems: if si.alias.len > 0 and si.alias == item.expr.name: + # When the aliased expression is an aggregate function, keep the + # ORDER BY as a column reference to the alias so the sort + # evaluates against the materialized aggregate output row instead + # of trying to re-evaluate the aggregate (which would fail). 
+ if si.expr != nil and si.expr.kind == ekFunc and + si.expr.funcName.toUpperAscii() in [ + "COUNT", "SUM", "AVG", "MIN", "MAX", + "GROUP_CONCAT", "STRING_AGG"]: + found = true + break expanded.orderBy[i] = OrderItem(expr: cloneExpr(si.expr), asc: item.asc) let aliasRes = bindExpr(map, expanded.orderBy[i].expr) if not aliasRes.ok: @@ -1518,9 +1543,9 @@ proc bindInsert(catalog: Catalog, stmt: Statement): Result[Statement] = return err[Statement](ERR_SQL, "Column count mismatch", stmt.insertTable) var colSet = initHashSet[string]() for col in viewRes.value.columnNames: - colSet.incl(col) + colSet.incl(normalizedName(col)) for colName in targetCols: - if not colSet.contains(colName): + if not colSet.contains(normalizedName(colName)): return err[Statement](ERR_SQL, "Unknown column", colName) return ok(stmt) let tableRes = catalog.getTable(stmt.insertTable) @@ -1541,9 +1566,9 @@ proc bindInsert(catalog: Catalog, stmt: Statement): Result[Statement] = # Validate target columns exist var colIndex = initTable[string, int]() for i, col in table.columns: - colIndex[col.name] = i + colIndex[normalizedName(col.name)] = i for colName in targetCols: - if not colIndex.hasKey(colName): + if not colIndex.hasKey(normalizedName(colName)): return err[Statement](ERR_SQL, "Unknown column", colName) # Validate SELECT column count matches target columns if stmt.insertSelectQuery.selectItems.len > 0: @@ -1569,14 +1594,14 @@ proc bindInsert(catalog: Catalog, stmt: Statement): Result[Statement] = return err[Statement](ERR_SQL, "Column count mismatch", stmt.insertTable) var colIndex = initTable[string, int]() for i, col in table.columns: - colIndex[col.name] = i + colIndex[normalizedName(col.name)] = i var ordered: seq[Expr] = newSeq[Expr](table.columns.len) for i in 0 ..< ordered.len: ordered[i] = nullExpr() for i, colName in targetCols: - if not colIndex.hasKey(colName): + if not colIndex.hasKey(normalizedName(colName)): return err[Statement](ERR_SQL, "Unknown column", colName) - let idx = 
colIndex[colName] + let idx = colIndex[normalizedName(colName)] let expr = stmt.insertValues[i] let typeRes = checkLiteralType(expr, table.columns[idx].kind, colName) if not typeRes.ok: @@ -1592,7 +1617,7 @@ proc bindInsert(catalog: Catalog, stmt: Statement): Result[Statement] = for i in 0 ..< rowOrdered.len: rowOrdered[i] = nullExpr() for i, colName in targetCols: - let idx = colIndex[colName] + let idx = colIndex[normalizedName(colName)] let expr = rowVals[i] let typeRes = checkLiteralType(expr, table.columns[idx].kind, colName) if not typeRes.ok: @@ -1613,7 +1638,7 @@ proc bindInsert(catalog: Catalog, stmt: Statement): Result[Statement] = if not catalog.indexes.hasKey(conflictTargetConstraint): return err[Statement](ERR_SQL, "Unknown ON CONFLICT constraint", conflictTargetConstraint) let idx = catalog.indexes[conflictTargetConstraint] - if idx.table != stmt.insertTable: + if idx.table.toLowerAscii() != stmt.insertTable.toLowerAscii(): return err[Statement](ERR_SQL, "ON CONFLICT constraint belongs to different table", conflictTargetConstraint) if not idx.unique: return err[Statement](ERR_SQL, "ON CONFLICT constraint must be unique", conflictTargetConstraint) @@ -1622,20 +1647,20 @@ proc bindInsert(catalog: Catalog, stmt: Statement): Result[Statement] = if conflictTargetCols.len > 0: var colSet = initHashSet[string]() for col in table.columns: - colSet.incl(col.name) + colSet.incl(normalizedName(col.name)) for name in conflictTargetCols: - if not colSet.contains(name): + if not colSet.contains(normalizedName(name)): return err[Statement](ERR_SQL, "Unknown ON CONFLICT target column", name) var matchedUniqueTarget = false if conflictTargetCols.len == 1: for col in table.columns: - if col.name == conflictTargetCols[0] and (col.primaryKey or col.unique): + if eqIdent(col.name, conflictTargetCols[0]) and (col.primaryKey or col.unique): matchedUniqueTarget = true break if not matchedUniqueTarget: for _, idx in catalog.indexes: - if idx.table == stmt.insertTable and 
idx.unique and idx.columns == conflictTargetCols: + if eqIdent(idx.table, stmt.insertTable) and idx.unique and eqIdentSeq(idx.columns, conflictTargetCols): matchedUniqueTarget = true break if not matchedUniqueTarget: @@ -1660,7 +1685,7 @@ proc bindInsert(catalog: Catalog, stmt: Statement): Result[Statement] = for colName, expr in stmt.insertConflictUpdateAssignments: var colIdx = -1 for i, col in table.columns: - if col.name == colName: + if eqIdent(col.name, colName): colIdx = i break if colIdx < 0: @@ -1720,9 +1745,9 @@ proc bindUpdate(catalog: Catalog, stmt: Statement): Result[Statement] = return err[Statement](viewRes.err.code, viewRes.err.message, viewRes.err.context) var colSet = initHashSet[string]() for col in viewRes.value.columnNames: - colSet.incl(col) + colSet.incl(normalizedName(col)) for colName, _ in stmt.assignments: - if not colSet.contains(colName): + if not colSet.contains(normalizedName(colName)): return err[Statement](ERR_SQL, "Unknown column", colName) return ok(stmt) let tableRes = catalog.getTable(stmt.updateTable) @@ -1732,7 +1757,7 @@ proc bindUpdate(catalog: Catalog, stmt: Statement): Result[Statement] = for colName, expr in stmt.assignments: var found = false for col in table.columns: - if col.name == colName: + if eqIdent(col.name, colName): found = true let typeRes = checkLiteralType(expr, col.kind, colName) if not typeRes.ok: @@ -1910,7 +1935,7 @@ proc bindCreateIndex(catalog: Catalog, stmt: Statement): Result[Statement] = for colName in stmt.columnNames: var found = false for col in tableRes.value.columns: - if col.name == colName: + if eqIdent(col.name, colName): found = true if stmt.indexKind == sql.ikTrigram and col.kind != ctText: return err[Statement](ERR_SQL, "Trigram index only supported on TEXT columns", colName) @@ -2068,7 +2093,7 @@ proc bindAlterTable(catalog: Catalog, stmt: Statement): Result[Statement] = return err[Statement](tableRes.err.code, tableRes.err.message, tableRes.err.context) let table = tableRes.value for _, 
idx in catalog.indexes: - if idx.table == stmt.alterTableName and idx.columns.len == 1 and idx.columns[0].startsWith(IndexExpressionPrefix): + if idx.table.toLowerAscii() == stmt.alterTableName.toLowerAscii() and idx.columns.len == 1 and idx.columns[0].startsWith(IndexExpressionPrefix): return err[Statement](ERR_SQL, "ALTER TABLE on tables with expression indexes is not supported in 0.x", stmt.alterTableName) if table.checks.len > 0: return err[Statement](ERR_SQL, "ALTER TABLE on tables with CHECK constraints is not supported in 0.x", stmt.alterTableName) @@ -2077,7 +2102,7 @@ proc bindAlterTable(catalog: Catalog, stmt: Statement): Result[Statement] = case action.kind of ataAddColumn: for col in table.columns: - if col.name == action.columnDef.name: + if eqIdent(col.name, action.columnDef.name): return err[Statement](ERR_SQL, "Column already exists", action.columnDef.name) let validTypes = ["INT", "INT64", "INTEGER", "TEXT", "BLOB", "BOOL", "BOOLEAN", "FLOAT", "FLOAT64", "REAL"] var typeValid = false @@ -2091,7 +2116,7 @@ proc bindAlterTable(catalog: Catalog, stmt: Statement): Result[Statement] = of ataDropColumn: var found = false for col in table.columns: - if col.name == action.columnName: + if eqIdent(col.name, action.columnName): found = true if col.primaryKey: return err[Statement](ERR_SQL, "Cannot drop PRIMARY KEY column", action.columnName) @@ -2104,9 +2129,9 @@ proc bindAlterTable(catalog: Catalog, stmt: Statement): Result[Statement] = var oldFound = false var newFound = false for col in table.columns: - if col.name == action.columnName: + if eqIdent(col.name, action.columnName): oldFound = true - if col.name == action.newColumnName: + if eqIdent(col.name, action.newColumnName): newFound = true if not oldFound: return err[Statement](ERR_SQL, "Column does not exist", action.columnName) @@ -2116,7 +2141,7 @@ proc bindAlterTable(catalog: Catalog, stmt: Statement): Result[Statement] = var found = false var foundCol = Column() for col in table.columns: - if 
col.name == action.columnName: + if eqIdent(col.name, action.columnName): found = true foundCol = col break diff --git a/src/sql/sql.nim b/src/sql/sql.nim index cd1aaef..1bad7b4 100644 --- a/src/sql/sql.nim +++ b/src/sql/sql.nim @@ -1369,14 +1369,22 @@ proc selectToCanonicalSql(stmt: Statement): string = parts.add("SELECT " & selectItems.join(", ")) if stmt.fromTable.len > 0: - var fromPart = "FROM " & quoteIdent(stmt.fromTable) - if stmt.fromAlias.len > 0: - fromPart.add(" " & quoteIdent(stmt.fromAlias)) + var fromPart: string + if stmt.fromSubquery != nil: + fromPart = "FROM (" & selectToCanonicalSql(stmt.fromSubquery) & ") " & quoteIdent(stmt.fromAlias) + else: + fromPart = "FROM " & quoteIdent(stmt.fromTable) + if stmt.fromAlias.len > 0: + fromPart.add(" " & quoteIdent(stmt.fromAlias)) parts.add(fromPart) - for join in stmt.joins: + for i, join in stmt.joins: let joinKeyword = if join.joinType == jtLeft: "LEFT JOIN " else: "INNER JOIN " - var joinPart = joinKeyword & quoteIdent(join.table) + var joinPart: string + if i < stmt.joinSubqueries.len and stmt.joinSubqueries[i] != nil: + joinPart = joinKeyword & "(" & selectToCanonicalSql(stmt.joinSubqueries[i]) & ")" + else: + joinPart = joinKeyword & quoteIdent(join.table) if join.alias.len > 0: joinPart.add(" " & quoteIdent(join.alias)) if join.onExpr != nil: @@ -1536,7 +1544,7 @@ proc parseCreateStmt(node: JsonNode): Result[Statement] = let colName = nodeString(keyNode) var found = false for i, col in columns: - if col.name == colName: + if col.name.toLowerAscii() == colName.toLowerAscii(): found = true columns[i].unique = keys.len == 1 # only mark unique for single-column if contype == "CONSTR_PRIMARY": diff --git a/src/storage/storage.nim b/src/storage/storage.nim index f21bfcb..c3a5a84 100644 --- a/src/storage/storage.nim +++ b/src/storage/storage.nim @@ -23,6 +23,11 @@ var gReusableBTree {.threadvar.}: BTree # Thread-local reusable record buffer to avoid per-insert allocation var gRecordBuf {.threadvar.}: 
seq[byte] +proc evictPagerFromReusableBTree*(pager: Pager) = + ## Release gReusableBTree if it references the given pager. + if gReusableBTree != nil and gReusableBTree.pager == pager: + gReusableBTree = nil + type StoredRow* = object rowid*: uint64 values*: seq[Value] @@ -75,7 +80,7 @@ proc isTextBlobIndex*(table: TableMeta, idx: IndexMeta): bool = if idx.columns.len != 1: return false for col in table.columns: - if col.name == idx.columns[0]: + if col.name.toLowerAscii() == idx.columns[0].toLowerAscii(): return col.kind in {ctText, ctBlob} false @@ -140,7 +145,7 @@ proc indexColumnIndices*(table: TableMeta, idx: IndexMeta): seq[int] = if colName.startsWith(IndexExpressionPrefix): continue for i, col in table.columns: - if col.name == colName: + if col.name.toLowerAscii() == colName.toLowerAscii(): result.add(i) break @@ -238,7 +243,7 @@ proc evalExpressionIndexExpr(table: TableMeta, values: seq[Value], expr: Expr): case expr.kind of ekColumn: for i, col in table.columns: - if col.name == expr.name: + if col.name.toLowerAscii() == expr.name.toLowerAscii(): if i < values.len: return ok(values[i]) return err[Value](ERR_CORRUPTION, "Row column count does not match table metadata", table.name) @@ -344,7 +349,7 @@ proc evalPredicateValue(table: TableMeta, values: seq[Value], expr: Expr): Value of svParam: return Value(kind: vkNull) of ekColumn: for i, col in table.columns: - if col.name == expr.name: + if col.name.toLowerAscii() == expr.name.toLowerAscii(): if i < values.len: return values[i] return Value(kind: vkNull) return Value(kind: vkNull) @@ -428,7 +433,7 @@ proc evalPredicateValue(table: TableMeta, values: seq[Value], expr: Expr): Value else: return Value(kind: vkNull) -proc shouldIncludeInIndex(table: TableMeta, idx: IndexMeta, values: seq[Value]): bool = +proc shouldIncludeInIndex*(table: TableMeta, idx: IndexMeta, values: seq[Value]): bool = if idx.predicateSql.len == 0: return true let predExpr = parsePredicateExpr(idx.predicateSql, idx.table) @@ -907,7 
+912,7 @@ proc insertRowInternal(pager: Pager, catalog: Catalog, tableName: string, values var trigramValues: seq[(IndexMeta, Value)] if updateIndexes: for _, idx in catalog.indexes: - if idx.table != tableName: + if idx.table.toLowerAscii() != tableName.toLowerAscii(): continue if not shouldIncludeInIndex(tablePtr[], idx, storedValues): continue @@ -919,7 +924,7 @@ proc insertRowInternal(pager: Pager, catalog: Catalog, tableName: string, values var idxValue = Value(kind: vkNull) if isTextBlobIndex(tablePtr[], idx) and idx.columns.len == 1: for i, col in tablePtr.columns: - if col.name == idx.columns[0]: + if col.name.toLowerAscii() == idx.columns[0].toLowerAscii(): idxValue = storedValues[i] break indexKeys.add((idx, keyRes.value, idxValue)) @@ -928,7 +933,7 @@ proc insertRowInternal(pager: Pager, catalog: Catalog, tableName: string, values if idx.columns.len == 1: var valueIndex = -1 for i, col in tablePtr.columns: - if col.name == idx.columns[0]: + if col.name.toLowerAscii() == idx.columns[0].toLowerAscii(): valueIndex = i break if valueIndex >= 0: @@ -1037,7 +1042,7 @@ proc insertRowInternal(pager: Pager, catalog: Catalog, tableName: string, values proc insertRow*(pager: Pager, catalog: Catalog, tableName: string, values: seq[Value]): Result[uint64] = var hasIndexes = false for _, idx in catalog.indexes: - if idx.table == tableName: + if idx.table.toLowerAscii() == tableName.toLowerAscii(): hasIndexes = true break insertRowInternal(pager, catalog, tableName, values, hasIndexes) @@ -1122,7 +1127,7 @@ proc updateRow*(pager: Pager, catalog: Catalog, tableName: string, rowid: uint64 var hasIndexes = false for _, idx in catalog.indexes: - if idx.table == tableName: + if idx.table.toLowerAscii() == tableName.toLowerAscii(): hasIndexes = true break @@ -1131,7 +1136,7 @@ proc updateRow*(pager: Pager, catalog: Catalog, tableName: string, rowid: uint64 if not oldRes.ok: return err[Void](oldRes.err.code, oldRes.err.message, oldRes.err.context) for _, idx in 
catalog.indexes: - if idx.table != tableName: + if idx.table.toLowerAscii() != tableName.toLowerAscii(): continue let oldIncluded = shouldIncludeInIndex(table, idx, oldRes.value.values) let newIncluded = shouldIncludeInIndex(table, idx, values) @@ -1146,7 +1151,7 @@ proc updateRow*(pager: Pager, catalog: Catalog, tableName: string, rowid: uint64 var oldIdxValue = Value(kind: vkNull) if textBlob and idx.columns.len == 1: for i, col in table.columns: - if col.name == idx.columns[0]: + if col.name.toLowerAscii() == idx.columns[0].toLowerAscii(): oldIdxValue = oldRes.value.values[i] break let delRes = deleteKeyValue(idxTree, oldKeyRes.value, encodeIndexEntry(rowid, oldIdxValue, textBlob)) @@ -1159,7 +1164,7 @@ proc updateRow*(pager: Pager, catalog: Catalog, tableName: string, rowid: uint64 var newIdxValue = Value(kind: vkNull) if textBlob and idx.columns.len == 1: for i, col in table.columns: - if col.name == idx.columns[0]: + if col.name.toLowerAscii() == idx.columns[0].toLowerAscii(): newIdxValue = values[i] break let idxInsert = insert(idxTree, newKeyRes.value, encodeIndexEntry(rowid, newIdxValue, textBlob)) @@ -1172,7 +1177,7 @@ proc updateRow*(pager: Pager, catalog: Catalog, tableName: string, rowid: uint64 if idx.columns.len == 1: var valueIndex = -1 for i, col in table.columns: - if col.name == idx.columns[0]: + if col.name.toLowerAscii() == idx.columns[0].toLowerAscii(): valueIndex = i break if valueIndex >= 0: @@ -1206,7 +1211,7 @@ proc deleteRow*(pager: Pager, catalog: Catalog, tableName: string, rowid: uint64 if not oldRes.ok: return err[Void](oldRes.err.code, oldRes.err.message, oldRes.err.context) for _, idx in catalog.indexes: - if idx.table != tableName: + if idx.table.toLowerAscii() != tableName.toLowerAscii(): continue if not shouldIncludeInIndex(table, idx, oldRes.value.values): continue @@ -1219,7 +1224,7 @@ proc deleteRow*(pager: Pager, catalog: Catalog, tableName: string, rowid: uint64 var oldIdxValue = Value(kind: vkNull) if textBlob and 
idx.columns.len == 1: for i, col in table.columns: - if col.name == idx.columns[0]: + if col.name.toLowerAscii() == idx.columns[0].toLowerAscii(): oldIdxValue = oldRes.value.values[i] break let delRes = deleteKeyValue(idxTree, oldKeyRes.value, encodeIndexEntry(rowid, oldIdxValue, textBlob)) @@ -1232,7 +1237,7 @@ proc deleteRow*(pager: Pager, catalog: Catalog, tableName: string, rowid: uint64 if idx.columns.len == 1: var valueIndex = -1 for i, col in table.columns: - if col.name == idx.columns[0]: + if col.name.toLowerAscii() == idx.columns[0].toLowerAscii(): valueIndex = i break if valueIndex >= 0: @@ -1253,7 +1258,7 @@ proc buildIndexForColumn*(pager: Pager, catalog: Catalog, tableName: string, col let table = tableRes.value var columnIndex = -1 for i, col in table.columns: - if col.name == columnName: + if col.name.toLowerAscii() == columnName.toLowerAscii(): columnIndex = i break if columnIndex < 0: @@ -1291,7 +1296,7 @@ proc buildIndexForColumns*(pager: Pager, catalog: Catalog, tableName: string, co for colName in columnNames: var found = false for i, col in table.columns: - if col.name == colName: + if col.name.toLowerAscii() == colName.toLowerAscii(): colIndices.add(i) found = true break @@ -1360,7 +1365,7 @@ proc buildTrigramIndexForColumn*(pager: Pager, catalog: Catalog, tableName: stri let table = tableRes.value var columnIndex = -1 for i, col in table.columns: - if col.name == columnName: + if col.name.toLowerAscii() == columnName.toLowerAscii(): columnIndex = i break if columnIndex < 0: @@ -1470,7 +1475,7 @@ proc indexSeek*(pager: Pager, catalog: Catalog, tableName: string, column: strin var idxOpt: Option[IndexMeta] = none(IndexMeta) if isExpressionIndexToken(column): for _, idx in catalog.indexes: - if idx.table == tableName and idx.kind == ikBtree and idx.columns.len == 1 and idx.columns[0] == column: + if idx.table.toLowerAscii() == tableName.toLowerAscii() and idx.kind == ikBtree and idx.columns.len == 1 and idx.columns[0].toLowerAscii() == 
column.toLowerAscii(): idxOpt = some(idx) break else: @@ -1661,7 +1666,7 @@ proc alterColumnTypeInTable(pager: Pager, catalog: Catalog, table: var TableMeta let originalTableMeta = table var columnIndex = -1 for i, col in table.columns: - if col.name == columnName: + if col.name.toLowerAscii() == columnName.toLowerAscii(): columnIndex = i break if columnIndex < 0: @@ -1734,7 +1739,7 @@ proc alterColumnTypeInTable(pager: Pager, catalog: Catalog, table: var TableMeta # Rebuild indexes against the rewritten table contents and updated column type. updateTableMeta(catalog, table) for _, idx in catalog.indexes: - if idx.table == table.name: + if idx.table.toLowerAscii() == table.name.toLowerAscii(): let rebuildRes = rebuildIndex(pager, catalog, idx) if not rebuildRes.ok: updateTableMeta(catalog, originalTableMeta) @@ -1744,7 +1749,7 @@ proc alterColumnTypeInTable(pager: Pager, catalog: Catalog, table: var TableMeta proc dropColumnFromTable(pager: Pager, catalog: Catalog, table: var TableMeta, columnName: string): Result[Void] = var columnIndex = -1 for i, col in table.columns: - if col.name == columnName: + if col.name.toLowerAscii() == columnName.toLowerAscii(): columnIndex = i break @@ -1758,7 +1763,7 @@ proc dropColumnFromTable(pager: Pager, catalog: Catalog, table: var TableMeta, c var indexesToDrop: seq[string] = @[] for idxName, idx in catalog.indexes: - if idx.table == table.name and idx.columns.len == 1 and idx.columns[0] == columnName: + if idx.table.toLowerAscii() == table.name.toLowerAscii() and idx.columns.len == 1 and idx.columns[0].toLowerAscii() == columnName.toLowerAscii(): indexesToDrop.add(idxName) for idxName in indexesToDrop: @@ -1813,7 +1818,7 @@ proc dropColumnFromTable(pager: Pager, catalog: Catalog, table: var TableMeta, c table.rootPage = newRoot for idxName, idx in catalog.indexes: - if idx.table == table.name: + if idx.table.toLowerAscii() == table.name.toLowerAscii(): let rebuildRes = rebuildIndex(pager, catalog, idx) if not rebuildRes.ok: 
return rebuildRes @@ -1822,7 +1827,7 @@ proc dropColumnFromTable(pager: Pager, catalog: Catalog, table: var TableMeta, c proc addColumnToTable(pager: Pager, catalog: Catalog, table: var TableMeta, colDef: ColumnDef): Result[Void] = for col in table.columns: - if col.name == colDef.name: + if col.name.toLowerAscii() == colDef.name.toLowerAscii(): return err[Void](ERR_SQL, "Column already exists", colDef.name) let colRes = columnFromColumnDef(colDef) @@ -1877,7 +1882,7 @@ proc addColumnToTable(pager: Pager, catalog: Catalog, table: var TableMeta, colD table.rootPage = newRoot for idxName, idx in catalog.indexes: - if idx.table == table.name: + if idx.table.toLowerAscii() == table.name.toLowerAscii(): let rebuildRes = rebuildIndex(pager, catalog, idx) if not rebuildRes.ok: return rebuildRes @@ -1887,9 +1892,9 @@ proc addColumnToTable(pager: Pager, catalog: Catalog, table: var TableMeta, colD proc renameColumnInTable(pager: Pager, catalog: Catalog, table: var TableMeta, oldName: string, newName: string): Result[Void] = var oldIndex = -1 for i, col in table.columns: - if col.name == oldName: + if col.name.toLowerAscii() == oldName.toLowerAscii(): oldIndex = i - if col.name == newName: + if col.name.toLowerAscii() == newName.toLowerAscii(): return err[Void](ERR_SQL, "Column already exists", newName) if oldIndex < 0: return err[Void](ERR_SQL, "Column not found", oldName) diff --git a/tests/nim/test_case_insensitive_identifiers.nim b/tests/nim/test_case_insensitive_identifiers.nim new file mode 100644 index 0000000..6f7b8bd --- /dev/null +++ b/tests/nim/test_case_insensitive_identifiers.nim @@ -0,0 +1,205 @@ +# Case-Insensitive Identifier Resolution Tests (ADR-0096) +# +# Verifies that DecentDB follows PostgreSQL semantics where unquoted identifiers +# are case-insensitive. Tables, columns, and indexes created with quoted identifiers +# must be accessible via unquoted (lowercased) identifiers and vice versa. 
+ +import unittest +import os +import strutils + +import engine + +proc makeTempDb(name: string): string = + let path = getTempDir() / name + if fileExists(path): + removeFile(path) + let walPath = path & "-wal" + if fileExists(walPath): + removeFile(walPath) + path + +suite "Case-insensitive table names": + var db: Db + + setup: + let path = makeTempDb("test_case_ident_tables.ddb") + let openRes = openDb(path, cachePages = 64) + check openRes.ok + db = openRes.value + + teardown: + discard closeDb(db) + + test "quoted CREATE TABLE accessible via unquoted SELECT": + check execSql(db, """CREATE TABLE "MyTable" ("Id" INTEGER PRIMARY KEY, "Name" TEXT)""").ok + check execSql(db, """INSERT INTO "MyTable" ("Id", "Name") VALUES (1, 'Alice')""").ok + + let sel = execSql(db, "SELECT id, name FROM mytable WHERE id = 1") + check sel.ok + check sel.value.len == 1 + check sel.value[0] == "1|Alice" + + test "unquoted CREATE TABLE accessible via quoted SELECT": + check execSql(db, "CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT)").ok + check execSql(db, "INSERT INTO users (id, email) VALUES (1, 'a@b.com')").ok + + let sel = execSql(db, """SELECT "id", "email" FROM "users" WHERE "id" = 1""") + check sel.ok + check sel.value.len == 1 + check sel.value[0] == "1|a@b.com" + + test "mixed case CREATE TABLE accessible both ways": + check execSql(db, """CREATE TABLE "ArtistStaging" ("ArtistId" INTEGER PRIMARY KEY, "ArtistName" TEXT)""").ok + check execSql(db, """INSERT INTO "ArtistStaging" ("ArtistId", "ArtistName") VALUES (1, 'Beatles')""").ok + + # Unquoted access + let sel1 = execSql(db, "SELECT artistid, artistname FROM artiststaging") + check sel1.ok + check sel1.value.len == 1 + check sel1.value[0] == "1|Beatles" + + # Mixed quoting + let sel2 = execSql(db, """SELECT "ArtistId", artistname FROM "ArtistStaging" """) + check sel2.ok + check sel2.value.len == 1 + +suite "Case-insensitive column names": + var db: Db + + setup: + let path = 
makeTempDb("test_case_ident_columns.ddb") + let openRes = openDb(path, cachePages = 64) + check openRes.ok + db = openRes.value + + teardown: + discard closeDb(db) + + test "INSERT with unquoted columns into quoted-CREATE table": + check execSql(db, """CREATE TABLE "Items" ("ItemId" INTEGER PRIMARY KEY, "ItemName" TEXT, "Price" REAL)""").ok + + let ins = execSql(db, "INSERT INTO items (itemid, itemname, price) VALUES (1, 'Widget', 9.99)") + check ins.ok + + let sel = execSql(db, "SELECT itemname, price FROM items WHERE itemid = 1") + check sel.ok + check sel.value.len == 1 + check "Widget" in sel.value[0] + + test "UPDATE with unquoted columns on quoted-CREATE table": + check execSql(db, """CREATE TABLE "Products" ("ProductId" INTEGER PRIMARY KEY, "ProductName" TEXT)""").ok + check execSql(db, """INSERT INTO "Products" ("ProductId", "ProductName") VALUES (1, 'Old')""").ok + + check execSql(db, "UPDATE products SET productname = 'New' WHERE productid = 1").ok + + let sel = execSql(db, "SELECT productname FROM products WHERE productid = 1") + check sel.ok + check sel.value.len == 1 + check sel.value[0] == "New" + + test "DELETE with unquoted WHERE on quoted-CREATE table": + check execSql(db, """CREATE TABLE "Logs" ("LogId" INTEGER PRIMARY KEY, "Message" TEXT)""").ok + check execSql(db, """INSERT INTO "Logs" ("LogId", "Message") VALUES (1, 'keep')""").ok + check execSql(db, """INSERT INTO "Logs" ("LogId", "Message") VALUES (2, 'delete')""").ok + + check execSql(db, "DELETE FROM logs WHERE logid = 2").ok + + let sel = execSql(db, "SELECT logid FROM logs") + check sel.ok + check sel.value.len == 1 + +suite "Case-insensitive INSERT ON CONFLICT": + var db: Db + + setup: + let path = makeTempDb("test_case_ident_conflict.ddb") + let openRes = openDb(path, cachePages = 64) + check openRes.ok + db = openRes.value + + teardown: + discard closeDb(db) + + test "ON CONFLICT with unquoted column targets on quoted-CREATE table": + check execSql(db, """CREATE TABLE "Settings" ("Key" 
TEXT PRIMARY KEY, "Value" TEXT)""").ok + check execSql(db, """INSERT INTO "Settings" ("Key", "Value") VALUES ('color', 'red')""").ok + + check execSql(db, "INSERT INTO settings (key, value) VALUES ('color', 'blue') ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value").ok + + let sel = execSql(db, "SELECT value FROM settings WHERE key = 'color'") + check sel.ok + check sel.value[0] == "blue" + +suite "Case-insensitive index names": + var db: Db + + setup: + let path = makeTempDb("test_case_ident_indexes.ddb") + let openRes = openDb(path, cachePages = 64) + check openRes.ok + db = openRes.value + + teardown: + discard closeDb(db) + + test "CREATE INDEX with quoted table/column, query with unquoted": + check execSql(db, """CREATE TABLE "Events" ("EventId" INTEGER PRIMARY KEY, "EventDate" TEXT, "Title" TEXT)""").ok + check execSql(db, """CREATE INDEX "idx_events_date" ON "Events" ("EventDate")""").ok + check execSql(db, "INSERT INTO events (eventid, eventdate, title) VALUES (1, '2025-01-01', 'New Year')").ok + + let sel = execSql(db, "SELECT title FROM events WHERE eventdate = '2025-01-01'") + check sel.ok + check sel.value.len == 1 + check sel.value[0] == "New Year" + +suite "Case-insensitive JOINs": + var db: Db + + setup: + let path = makeTempDb("test_case_ident_joins.ddb") + let openRes = openDb(path, cachePages = 64) + check openRes.ok + db = openRes.value + + teardown: + discard closeDb(db) + + test "JOIN with mixed quoting": + check execSql(db, """CREATE TABLE "Authors" ("AuthorId" INTEGER PRIMARY KEY, "AuthorName" TEXT)""").ok + check execSql(db, """CREATE TABLE "Books" ("BookId" INTEGER PRIMARY KEY, "AuthorId" INTEGER, "Title" TEXT)""").ok + check execSql(db, """INSERT INTO "Authors" ("AuthorId", "AuthorName") VALUES (1, 'Tolkien')""").ok + check execSql(db, """INSERT INTO "Books" ("BookId", "AuthorId", "Title") VALUES (1, 1, 'The Hobbit')""").ok + + let sel = execSql(db, """ + SELECT a.authorname, b.title + FROM authors a + INNER JOIN books b ON a.authorid 
= b.authorid + """) + check sel.ok + check sel.value.len == 1 + check sel.value[0] == "Tolkien|The Hobbit" + +suite "Case-insensitive subqueries": + var db: Db + + setup: + let path = makeTempDb("test_case_ident_subquery.ddb") + let openRes = openDb(path, cachePages = 64) + check openRes.ok + db = openRes.value + + teardown: + discard closeDb(db) + + test "INSERT...SELECT with mixed quoting": + check execSql(db, """CREATE TABLE "Staging" ("StagingId" INTEGER PRIMARY KEY, "Val" TEXT)""").ok + check execSql(db, """CREATE TABLE "Final" ("FinalId" INTEGER PRIMARY KEY, "Val" TEXT)""").ok + check execSql(db, """INSERT INTO "Staging" ("StagingId", "Val") VALUES (1, 'data')""").ok + + check execSql(db, "INSERT INTO final (finalid, val) SELECT stagingid, val FROM staging").ok + + let sel = execSql(db, "SELECT val FROM final") + check sel.ok + check sel.value.len == 1 + check sel.value[0] == "data" diff --git a/tests/nim/test_sql_exec.nim b/tests/nim/test_sql_exec.nim index 5cb3866..3045384 100644 --- a/tests/nim/test_sql_exec.nim +++ b/tests/nim/test_sql_exec.nim @@ -829,14 +829,10 @@ suite "Planner": check execSql(db, "INSERT INTO items VALUES (4, 20)").ok check execSql(db, "CREATE INDEX items_val_partial ON items (val) WHERE val IS NOT NULL").ok + # Partial indexes are excluded from the general query planner (ADR-0100) + # so the query falls back to sequential scan. Verify data correctness below. 
let explainRes = execSql(db, "EXPLAIN SELECT id FROM items WHERE val = 10") check explainRes.ok - var sawIndexSeek = false - for line in explainRes.value: - if line.contains("IndexSeek("): - sawIndexSeek = true - break - check sawIndexSeek let q1 = execSql(db, "SELECT id FROM items WHERE val = 10 ORDER BY id") check q1.ok diff --git a/tests/test_capi_reopen b/tests/test_capi_reopen new file mode 100755 index 0000000..2bfb82f Binary files /dev/null and b/tests/test_capi_reopen differ diff --git a/tests/test_join_reopen b/tests/test_join_reopen new file mode 100755 index 0000000..d91f5ca Binary files /dev/null and b/tests/test_join_reopen differ diff --git a/tests/test_many_tables_reopen b/tests/test_many_tables_reopen new file mode 100755 index 0000000..64ca005 Binary files /dev/null and b/tests/test_many_tables_reopen differ