From c677d7c2e82912ee59aa93c7b0666341e9f62fe0 Mon Sep 17 00:00:00 2001 From: Jesse Stuart Date: Sun, 12 Apr 2026 17:16:37 -0400 Subject: [PATCH 1/6] refactor: add diff Result alias + DiffErrorBuilder takes Into<_> for each statement --- src/diff.rs | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/src/diff.rs b/src/diff.rs index 26bbac6..45f49a4 100644 --- a/src/diff.rs +++ b/src/diff.rs @@ -36,10 +36,10 @@ impl fmt::Display for DiffError { #[bon] impl DiffError { #[builder] - fn new( + pub(crate) fn new( kind: DiffErrorKind, - statement_a: Option, - statement_b: Option, + #[builder(into)] statement_a: Option, + #[builder(into)] statement_b: Option, ) -> Self { Self { kind, @@ -51,7 +51,7 @@ impl DiffError { #[derive(Error, Debug)] #[non_exhaustive] -enum DiffErrorKind { +pub enum DiffErrorKind { #[error("can't drop unnamed index")] DropUnnamedIndex, #[error("can't compare unnamed index")] @@ -62,16 +62,18 @@ enum DiffErrorKind { NotImplemented, } +pub type Result = std::result::Result; + pub(crate) trait Diff: Sized { type Diff; - fn diff(&self, other: &Self) -> Result; + fn diff(&self, other: &Self) -> Result; } impl Diff for Vec { type Diff = Option>; - fn diff(&self, other: &Self) -> Result { + fn diff(&self, other: &Self) -> Result { let res = self .iter() .filter_map(|sa| { @@ -155,10 +157,10 @@ fn find_and_compare( other: &[Statement], match_fn: MF, drop_fn: DF, -) -> Result>, DiffError> +) -> Result>> where MF: Fn(&&Statement) -> bool, - DF: Fn() -> Result>, DiffError>, + DF: Fn() -> Result>>, { other.iter().find(match_fn).map_or_else( // drop the statement if it wasn't found in `other` @@ -172,7 +174,7 @@ fn find_and_compare_create_table( sa: &Statement, a: &CreateTable, other: &[Statement], -) -> Result>, DiffError> { +) -> Result>> { find_and_compare( sa, other, @@ -199,7 +201,7 @@ fn find_and_compare_create_index( sa: &Statement, a: &CreateIndex, other: &[Statement], -) -> Result>, DiffError> { +) -> Result>> { find_and_compare( sa, other, @@ -233,7 +235,7 @@ fn find_and_compare_create_type( sa: &Statement, a_name: &ObjectName, other: &[Statement], -) -> Result>, DiffError> { +) -> Result>> { find_and_compare( sa, other, @@ -262,7 +264,7 @@ fn find_and_compare_create_extension( if_not_exists: bool, cascade: bool, other: &[Statement], -) -> Result>, DiffError> { +) -> Result>> { find_and_compare( sa, other, @@ -288,7 +290,7 @@ fn find_and_compare_create_domain( orig: &Statement, domain: &CreateDomain, other: &[Statement], -) -> Result>, DiffError> { +) -> Result>> { let res = other .iter() .find(|sb| match sb { @@ -304,7 +306,7 @@ fn find_and_compare_create_domain( impl Diff for Statement { type Diff = Option>; - fn diff(&self, other: &Self) -> Result { + fn diff(&self, other: &Self) -> Result { match self { Self::CreateTable(a) => match other { Self::CreateTable(b) => Ok(compare_create_table(a, b)), @@ -392,10 +394,7 @@ fn compare_create_table(a: &CreateTable, b: &CreateTable) -> Option Result>, DiffError> { +fn compare_create_index(a: &CreateIndex, b: &CreateIndex) -> Result>> { if a == b { return Ok(None); } @@ -431,7 +430,7 @@ fn compare_create_type( b: &Statement, b_name: &ObjectName, b_rep: &Option, -) -> Result>, DiffError> { +) -> Result>> { if a_name == b_name && a_rep == b_rep { return Ok(None); } From 30cb1a09deef9484af50121fd766db27d5e6fcd2 Mon Sep 17 00:00:00 2001 From: Jesse Stuart Date: Sun, 12 Apr 2026 17:35:18 -0400 Subject: [PATCH 2/6] refactor: route sqlparser imports via new ast module --- src/ast.rs | 8 ++++++++ src/diff.rs | 23 ++++++++++++----------- src/lib.rs | 8 +++++--- src/migration.rs | 14 ++++++-------- src/name_gen.rs | 11 ++++++----- 5 files changed, 37 insertions(+), 27 deletions(-) create mode 100644 src/ast.rs diff --git a/src/ast.rs b/src/ast.rs new file mode 100644 index 0000000..0e3218b --- /dev/null +++ b/src/ast.rs @@ -0,0 +1,8 @@ +pub use sqlparser::ast::{ + helpers::attached_token::AttachedToken, AlterColumnOperation, AlterTable, AlterTableOperation, + AlterType, AlterTypeAddValue, AlterTypeAddValuePosition, AlterTypeOperation, + AlterTypeRenameValue, ColumnDef, ColumnOption, ColumnOptionDef, CreateDomain, CreateExtension, + CreateIndex, CreateTable, DropDomain, DropExtension, GeneratedAs, Ident, ObjectName, + ObjectNamePart, ObjectType, ReferentialAction, RenameTableNameKind, Statement, + UserDefinedTypeRepresentation, +}; diff --git a/src/diff.rs b/src/diff.rs index 45f49a4..a4ebb84 100644 --- a/src/diff.rs +++ b/src/diff.rs @@ -1,14 +1,15 @@ use std::{cmp::Ordering, collections::HashSet, fmt}; use bon::bon; -use sqlparser::ast::{ - helpers::attached_token::AttachedToken, AlterTable, AlterTableOperation, AlterType, - AlterTypeAddValue, AlterTypeAddValuePosition, AlterTypeOperation, CreateDomain, - CreateExtension, CreateIndex, CreateTable, DropDomain, DropExtension, Ident, ObjectName, - ObjectType, Statement, UserDefinedTypeRepresentation, -}; use thiserror::Error; +use crate::ast::{ + AlterTable, AlterTableOperation, AlterType, AlterTypeAddValue, AlterTypeAddValuePosition, + AlterTypeOperation, AttachedToken, CreateDomain, CreateExtension, CreateIndex, CreateTable, + DropDomain, DropExtension, Ident, ObjectName, ObjectType, Statement, + UserDefinedTypeRepresentation, +}; + #[derive(Error, Debug)] pub struct DiffError { kind: DiffErrorKind, @@ -184,7 +185,7 @@ fn find_and_compare_create_table( }, || { Ok(Some(vec![Statement::Drop { - object_type: sqlparser::ast::ObjectType::Table, + object_type: crate::ast::ObjectType::Table, if_exists: a.if_not_exists, names: vec![a.name.clone()], cascade: false, @@ -218,7 +219,7 @@ fn find_and_compare_create_index( })?; Ok(Some(vec![Statement::Drop { - object_type: sqlparser::ast::ObjectType::Index, + object_type: crate::ast::ObjectType::Index, if_exists: a.if_not_exists, names: vec![name], cascade: false, @@ -245,7 +246,7 @@ fn find_and_compare_create_type( }, || { Ok(Some(vec![Statement::Drop { - object_type: sqlparser::ast::ObjectType::Type, + object_type: crate::ast::ObjectType::Type, if_exists: false, names: vec![a_name.clone()], cascade: false, @@ -277,7 +278,7 @@ fn find_and_compare_create_extension( names: vec![a_name.clone()], if_exists: if_not_exists, cascade_or_restrict: if cascade { - Some(sqlparser::ast::ReferentialAction::Cascade) + Some(crate::ast::ReferentialAction::Cascade) } else { None }, @@ -448,7 +449,7 @@ fn compare_create_type( None } else { Some(AlterTypeOperation::RenameValue( - sqlparser::ast::AlterTypeRenameValue { + crate::ast::AlterTypeRenameValue { from: a.clone(), to: b.clone(), }, diff --git a/src/lib.rs b/src/lib.rs index 963127c..069727e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,15 +1,17 @@ use std::fmt; use bon::bon; -use diff::Diff; -use migration::Migrate; use sqlparser::{ - ast::Statement, dialect::{self}, parser::{self, Parser}, }; use thiserror::Error; +use ast::Statement; +use diff::Diff; +use migration::Migrate; + +mod ast; mod diff; mod migration; pub mod name_gen; diff --git a/src/migration.rs b/src/migration.rs index 72c9aa2..ff2d1bd 100644 --- a/src/migration.rs +++ b/src/migration.rs @@ -1,12 +1,13 @@ use std::fmt; use bon::bon; -use sqlparser::ast::{ +use thiserror::Error; + +use crate::ast::{ AlterColumnOperation, AlterTable, AlterTableOperation, AlterType, AlterTypeAddValuePosition, AlterTypeOperation, ColumnOption, ColumnOptionDef, CreateExtension, CreateTable, DropExtension, GeneratedAs, ObjectName, ObjectNamePart, ObjectType, Statement, UserDefinedTypeRepresentation, }; -use thiserror::Error; #[derive(Error, Debug)] pub struct MigrateError { @@ -267,10 +268,8 @@ fn migrate_alter_table( t.columns.push(column_def.clone()); } AlterTableOperation::DropColumn { column_names, .. } => { - t.columns.retain(|c| { - !column_names - .iter().any(|name| c.name.value == name.value) - }); + t.columns + .retain(|c| !column_names.iter().any(|name| c.name.value == name.value)); } AlterTableOperation::AlterColumn { column_name, op } => { t.columns.iter_mut().for_each(|c| { @@ -316,8 +315,7 @@ fn migrate_alter_table( c.options.push(ColumnOptionDef { name: None, option: ColumnOption::Generated { - generated_as: (*generated_as) - .unwrap_or(GeneratedAs::Always), + generated_as: (*generated_as).unwrap_or(GeneratedAs::Always), sequence_options: sequence_options.clone(), generation_expr: None, generation_expr_mode: None, diff --git a/src/name_gen.rs b/src/name_gen.rs index 0825879..178ea36 100644 --- a/src/name_gen.rs +++ b/src/name_gen.rs @@ -1,10 +1,11 @@ -use sqlparser::ast::{ - AlterTable, AlterTableOperation, AlterType, ColumnDef, CreateIndex, CreateTable, ObjectName, - ObjectType, RenameTableNameKind, Statement, +use crate::{ + ast::{ + AlterTable, AlterTableOperation, AlterType, ColumnDef, CreateIndex, CreateTable, + ObjectName, ObjectType, RenameTableNameKind, Statement, + }, + SyntaxTree, }; -use crate::SyntaxTree; - #[bon::builder(finish_fn = build)] pub fn generate_name( #[builder(start_fn)] tree: &SyntaxTree, From d20be9cd528552903a33da9680e5acfee6e2dccb Mon Sep 17 00:00:00 2001 From: Jesse Stuart Date: Mon, 13 Apr 2026 10:34:41 -0400 Subject: [PATCH 3/6] refactor: each test case has own fn via macro_rules --- src/lib.rs | 676 ++++++++++++++++++++---------------------------- src/name_gen.rs | 132 +++++----- 2 files changed, 359 insertions(+), 449 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 069727e..0e83e0f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -127,6 +127,30 @@ impl fmt::Display for SyntaxTree { mod tests { use super::*; + macro_rules! test_case { + ( + @dialect($dialect:path) $(,)? + + $( + $test_name:ident { $( $field:ident : $value:literal ),+ $(,)? } + ),* $(,)? + + => $test_fn:expr $(,)? + ) => { + $( + #[test] + fn $test_name() { + let test_case = TestCase { + dialect: $dialect, + $( $field : $value ),+ + }; + + run_test_case(&test_case, $test_fn); + } + )* + }; + } + #[derive(Debug)] struct TestCase { dialect: Dialect, @@ -135,9 +159,10 @@ mod tests { expect: &'static str, } - fn run_test_case(tc: &TestCase, testfn: F) + fn run_test_case(tc: &TestCase, testfn: F) where - F: Fn(SyntaxTree, SyntaxTree) -> SyntaxTree, + F: Fn(SyntaxTree, SyntaxTree) -> Result, E>, + E: std::error::Error, { let ast_a = SyntaxTree::builder() .dialect(tc.dialect) @@ -154,463 +179,334 @@ mod tests { .sql(tc.expect) .build() .unwrap_or_else(|_| panic!("invalid SQL: {:?}", tc.expect)); - let actual = testfn(ast_a, ast_b); + let actual = testfn(ast_a, ast_b) + .inspect_err(|err| eprintln!("Error: {err:?}")) + .unwrap() + .unwrap(); assert_eq!(actual.to_string(), tc.expect, "{tc:?}"); } - fn run_test_cases(test_cases: Vec, testfn: F) - where - F: Fn(SyntaxTree, SyntaxTree) -> Result, E>, - { - test_cases.into_iter().for_each(|tc| { - run_test_case(&tc, |ast_a, ast_b| { - testfn(ast_a, ast_b) - .inspect_err(|err| eprintln!("Error: {err:?}")) - .unwrap() - .unwrap() - }) - }); - } + mod test_diff { + use super::*; - #[test] - fn diff_create_table() { - run_test_cases( - vec![ - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE TABLE foo(\ - id int PRIMARY KEY - )", - sql_b: "CREATE TABLE foo(\ - id int PRIMARY KEY - );\ - CREATE TABLE bar (id INT PRIMARY KEY);", - expect: "CREATE TABLE bar (id INT PRIMARY KEY);", - }, - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE TABLE foo(\ - id int PRIMARY KEY - )", - sql_b: "CREATE TABLE foo(\ - \"id\" int PRIMARY KEY - );\ - CREATE TABLE bar (id INT PRIMARY KEY);", - expect: "CREATE TABLE bar (id INT PRIMARY KEY);", - }, - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE TABLE foo(\ - \"id\" int PRIMARY KEY - )", - sql_b: "CREATE TABLE foo(\ - id int PRIMARY KEY - );\ - CREATE TABLE bar (id INT PRIMARY KEY);", - expect: "CREATE TABLE bar (id INT PRIMARY KEY);", - }, - ], - |ast_a, ast_b| ast_a.diff(&ast_b), - ); - } + test_case!( + @dialect(Dialect::Generic) + + create_table_a { + sql_a: "CREATE TABLE foo(\ + id int PRIMARY KEY + )", + sql_b: "CREATE TABLE foo(\ + id int PRIMARY KEY + );\ + CREATE TABLE bar (id INT PRIMARY KEY);", + expect: "CREATE TABLE bar (id INT PRIMARY KEY);", + }, + + create_table_b { + sql_a: "CREATE TABLE foo(\ + id int PRIMARY KEY + )", + sql_b: "CREATE TABLE foo(\ + \"id\" int PRIMARY KEY + );\ + CREATE TABLE bar (id INT PRIMARY KEY);", + expect: "CREATE TABLE bar (id INT PRIMARY KEY);", + }, + + create_table_c { + sql_a: "CREATE TABLE foo(\ + \"id\" int PRIMARY KEY + )", + sql_b: "CREATE TABLE foo(\ + id int PRIMARY KEY + );\ + CREATE TABLE bar (id INT PRIMARY KEY);", + expect: "CREATE TABLE bar (id INT PRIMARY KEY);", + }, - #[test] - fn diff_drop_table() { - run_test_cases( - vec![TestCase { - dialect: Dialect::Generic, + drop_table_a { sql_a: "CREATE TABLE foo(\ - id int PRIMARY KEY - );\ + id int PRIMARY KEY + );\ CREATE TABLE bar (id INT PRIMARY KEY);", sql_b: "CREATE TABLE foo(\ - id int PRIMARY KEY - )", + id int PRIMARY KEY + )", expect: "DROP TABLE bar;", - }], - |ast_a, ast_b| ast_a.diff(&ast_b), - ); - } + }, - #[test] - fn diff_add_column() { - run_test_cases( - vec![TestCase { - dialect: Dialect::Generic, + add_column_a { sql_a: "CREATE TABLE foo(\ - id int PRIMARY KEY - )", + id int PRIMARY KEY + )", sql_b: "CREATE TABLE foo(\ - id int PRIMARY KEY, - bar text - )", + id int PRIMARY KEY, + bar text + )", expect: "ALTER TABLE\n foo\nADD\n COLUMN bar TEXT;", - }], - |ast_a, ast_b| ast_a.diff(&ast_b), - ); - } + }, - #[test] - fn diff_drop_column() { - run_test_cases( - vec![TestCase { - dialect: Dialect::Generic, + drop_column_a { sql_a: "CREATE TABLE foo(\ - id int PRIMARY KEY, - bar text - )", + id int PRIMARY KEY, + bar text + )", sql_b: "CREATE TABLE foo(\ - id int PRIMARY KEY - )", + id int PRIMARY KEY + )", expect: "ALTER TABLE\n foo DROP COLUMN bar;", - }], - |ast_a, ast_b| ast_a.diff(&ast_b), - ); - } - - #[test] - fn diff_create_index() { - run_test_cases( - vec![ - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE UNIQUE INDEX title_idx ON films (title);", - sql_b: "CREATE UNIQUE INDEX title_idx ON films ((lower(title)));", - expect: "DROP INDEX title_idx;\n\nCREATE UNIQUE INDEX title_idx ON films((lower(title)));", - }, - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE UNIQUE INDEX IF NOT EXISTS title_idx ON films (title);", - sql_b: "CREATE UNIQUE INDEX IF NOT EXISTS title_idx ON films ((lower(title)));", - expect: "DROP INDEX IF EXISTS title_idx;\n\nCREATE UNIQUE INDEX IF NOT EXISTS title_idx ON films((lower(title)));", - }, - ], - |ast_a, ast_b| ast_a.diff(&ast_b), - ); - } - - #[test] - fn diff_create_type() { - run_test_cases( - vec![ - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE TYPE bug_status AS ENUM ('new', 'open');", - sql_b: "CREATE TYPE foo AS ENUM ('bar');", - expect: "DROP TYPE bug_status;\n\nCREATE TYPE foo AS ENUM ('bar');", - }, - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'closed');", - sql_b: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'assigned', 'closed');", - expect: "ALTER TYPE bug_status\nADD\n VALUE 'assigned'\nAFTER\n 'open';", - }, - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE TYPE bug_status AS ENUM ('open', 'closed');", - sql_b: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'closed');", - expect: "ALTER TYPE bug_status\nADD\n VALUE 'new' BEFORE 'open';", - }, - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE TYPE bug_status AS ENUM ('new', 'open');", - sql_b: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'closed');", - expect: "ALTER TYPE bug_status\nADD\n VALUE 'closed';", - }, - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE TYPE bug_status AS ENUM ('new', 'open');", - sql_b: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'assigned', 'closed');", - expect: "ALTER TYPE bug_status\nADD\n VALUE 'assigned';\n\nALTER TYPE bug_status\nADD\n VALUE 'closed';", - }, - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE TYPE bug_status AS ENUM ('open', 'critical');", - sql_b: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'assigned', 'closed', 'critical');", - expect: "ALTER TYPE bug_status\nADD\n VALUE 'new' BEFORE 'open';\n\nALTER TYPE bug_status\nADD\n VALUE 'assigned'\nAFTER\n 'open';\n\nALTER TYPE bug_status\nADD\n VALUE 'closed'\nAFTER\n 'assigned';", - }, - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE TYPE bug_status AS ENUM ('open');", - sql_b: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'closed');", - expect: "ALTER TYPE bug_status\nADD\n VALUE 'new' BEFORE 'open';\n\nALTER TYPE bug_status\nADD\n VALUE 'closed';", - }, - ], - |ast_a, ast_b| ast_a.diff(&ast_b), - ); - } - - #[test] - fn diff_create_extension() { - run_test_cases( - vec![TestCase { - dialect: Dialect::Generic, + }, + + create_index_a { + sql_a: "CREATE UNIQUE INDEX title_idx ON films (title);", + sql_b: "CREATE UNIQUE INDEX title_idx ON films ((lower(title)));", + expect: "DROP INDEX title_idx;\n\nCREATE UNIQUE INDEX title_idx ON films((lower(title)));", + }, + + create_index_b { + sql_a: "CREATE UNIQUE INDEX IF NOT EXISTS title_idx ON films (title);", + sql_b: "CREATE UNIQUE INDEX IF NOT EXISTS title_idx ON films ((lower(title)));", + expect: "DROP INDEX IF EXISTS title_idx;\n\nCREATE UNIQUE INDEX IF NOT EXISTS title_idx ON films((lower(title)));", + }, + + create_type_a { + sql_a: "CREATE TYPE bug_status AS ENUM ('new', 'open');", + sql_b: "CREATE TYPE foo AS ENUM ('bar');", + expect: "DROP TYPE bug_status;\n\nCREATE TYPE foo AS ENUM ('bar');", + }, + + create_type_b { + sql_a: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'closed');", + sql_b: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'assigned', 'closed');", + expect: "ALTER TYPE bug_status\nADD\n VALUE 'assigned'\nAFTER\n 'open';", + }, + + create_type_c { + sql_a: "CREATE TYPE bug_status AS ENUM ('open', 'closed');", + sql_b: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'closed');", + expect: "ALTER TYPE bug_status\nADD\n VALUE 'new' BEFORE 'open';", + }, + + create_type_d { + sql_a: "CREATE TYPE bug_status AS ENUM ('new', 'open');", + sql_b: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'closed');", + expect: "ALTER TYPE bug_status\nADD\n VALUE 'closed';", + }, + + create_type_e { + sql_a: "CREATE TYPE bug_status AS ENUM ('new', 'open');", + sql_b: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'assigned', 'closed');", + expect: "ALTER TYPE bug_status\nADD\n VALUE 'assigned';\n\nALTER TYPE bug_status\nADD\n VALUE 'closed';", + }, + + create_type_f { + sql_a: "CREATE TYPE bug_status AS ENUM ('open', 'critical');", + sql_b: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'assigned', 'closed', 'critical');", + expect: "ALTER TYPE bug_status\nADD\n VALUE 'new' BEFORE 'open';\n\nALTER TYPE bug_status\nADD\n VALUE 'assigned'\nAFTER\n 'open';\n\nALTER TYPE bug_status\nADD\n VALUE 'closed'\nAFTER\n 'assigned';", + }, + + create_type_g { + sql_a: "CREATE TYPE bug_status AS ENUM ('open');", + sql_b: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'closed');", + expect: "ALTER TYPE bug_status\nADD\n VALUE 'new' BEFORE 'open';\n\nALTER TYPE bug_status\nADD\n VALUE 'closed';", + }, + + create_extension_a { sql_a: "CREATE EXTENSION hstore;", sql_b: "CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";", expect: "DROP EXTENSION hstore;\n\nCREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";", - }], - |ast_a, ast_b| ast_a.diff(&ast_b), + }, + + => |ast_a, ast_b| { + ast_a.diff(&ast_b) + } ); - } - #[test] - fn diff_create_domain() { - run_test_cases( - vec![TestCase { - dialect: Dialect::PostgreSql, + test_case!( + @dialect(Dialect::Generic) + + create_domain_a { sql_a: "", sql_b: "CREATE DOMAIN email AS VARCHAR(255) CHECK (VALUE ~ '^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$');", expect: "CREATE DOMAIN email AS VARCHAR(255) CHECK (\n VALUE ~ '^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$'\n);", - }], - |ast_a, ast_b| ast_a.diff(&ast_b), - ); - } + }, - #[test] - fn diff_edit_domain() { - run_test_cases( - vec![TestCase { - dialect: Dialect::PostgreSql, + edit_domain_a { sql_a: "CREATE DOMAIN positive_int AS INTEGER CHECK (VALUE > 0);", sql_b: "CREATE DOMAIN positive_int AS BIGINT CHECK (VALUE > 0 AND VALUE < 1000000);", expect: "DROP DOMAIN IF EXISTS positive_int;\n\nCREATE DOMAIN positive_int AS BIGINT CHECK (\n VALUE > 0\n AND VALUE < 1000000\n);", - }], - |ast_a, ast_b| ast_a.diff(&ast_b), + }, + + => |ast_a, ast_b| { + ast_a.diff(&ast_b) + } ); } - #[test] - fn apply_create_table() { - run_test_cases( - vec![TestCase { - dialect: Dialect::Generic, + mod migrate { + use super::*; + + test_case!( + @dialect(Dialect::Generic) + + create_table_a { sql_a: "CREATE TABLE bar (id INT PRIMARY KEY);", sql_b: "CREATE TABLE foo (id INT PRIMARY KEY);", expect: "CREATE TABLE bar (id INT PRIMARY KEY);\n\nCREATE TABLE foo (id INT PRIMARY KEY);", - }], - |ast_a, ast_b| ast_a.migrate(&ast_b), - ); - } + }, - #[test] - fn apply_drop_table() { - run_test_cases( - vec![TestCase { - dialect: Dialect::Generic, + drop_table_a { sql_a: "CREATE TABLE bar (id INT PRIMARY KEY)", sql_b: "DROP TABLE bar; CREATE TABLE foo (id INT PRIMARY KEY)", expect: "CREATE TABLE foo (id INT PRIMARY KEY);", - }], - |ast_a, ast_b| ast_a.migrate(&ast_b), - ); - } + }, - #[test] - fn apply_alter_table_add_column() { - run_test_cases( - vec![TestCase { - dialect: Dialect::Generic, + alter_table_add_column_a { sql_a: "CREATE TABLE bar (id INT PRIMARY KEY)", sql_b: "ALTER TABLE bar ADD COLUMN bar TEXT", expect: "CREATE TABLE bar (id INT PRIMARY KEY, bar TEXT);", - }], - |ast_a, ast_b| ast_a.migrate(&ast_b), - ); - } + }, - #[test] - fn apply_alter_table_drop_column() { - run_test_cases( - vec![TestCase { - dialect: Dialect::Generic, + alter_table_drop_column_a { sql_a: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)", sql_b: "ALTER TABLE bar DROP COLUMN bar", expect: "CREATE TABLE bar (id INT PRIMARY KEY);", - }], - |ast_a, ast_b| ast_a.migrate(&ast_b), - ); - } - - #[test] - fn apply_alter_table_drop_columns_snowflake() { - run_test_cases( - vec![TestCase { - dialect: Dialect::Snowflake, - sql_a: "CREATE TABLE bar (foo TEXT, bar TEXT, id INT PRIMARY KEY)", - sql_b: "ALTER TABLE bar DROP COLUMN foo, bar", - expect: "CREATE TABLE bar (id INT PRIMARY KEY);", - }], - |ast_a, ast_b| ast_a.migrate(&ast_b), - ); - } + }, - #[test] - fn apply_alter_table_alter_column() { - run_test_cases( - vec![ - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)", - sql_b: "ALTER TABLE bar ALTER COLUMN bar SET NOT NULL", - expect: "CREATE TABLE bar (bar TEXT NOT NULL, id INT PRIMARY KEY);", - }, - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE TABLE bar (bar TEXT NOT NULL, id INT PRIMARY KEY)", - sql_b: "ALTER TABLE bar ALTER COLUMN bar DROP NOT NULL", - expect: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY);", - }, - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE TABLE bar (bar TEXT NOT NULL DEFAULT 'foo', id INT PRIMARY KEY)", - sql_b: "ALTER TABLE bar ALTER COLUMN bar DROP DEFAULT", - expect: "CREATE TABLE bar (bar TEXT NOT NULL, id INT PRIMARY KEY);", - }, - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)", - sql_b: "ALTER TABLE bar ALTER COLUMN bar SET DATA TYPE INTEGER", - expect: "CREATE TABLE bar (bar INTEGER, id INT PRIMARY KEY);", - }, - TestCase { - dialect: Dialect::PostgreSql, - sql_a: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)", - sql_b: "ALTER TABLE bar ALTER COLUMN bar SET DATA TYPE timestamp with time zone\n USING timestamp with time zone 'epoch' + foo_timestamp * interval '1 second'", - expect: "CREATE TABLE bar (bar TIMESTAMP WITH TIME ZONE, id INT PRIMARY KEY);", - }, - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE TABLE bar (bar INTEGER, id INT PRIMARY KEY)", - sql_b: "ALTER TABLE bar ALTER COLUMN bar ADD GENERATED BY DEFAULT AS IDENTITY", - expect: "CREATE TABLE bar (\n bar INTEGER GENERATED BY DEFAULT AS IDENTITY,\n id INT PRIMARY KEY\n);", - }, - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE TABLE bar (bar INTEGER, id INT PRIMARY KEY)", - sql_b: "ALTER TABLE bar ALTER COLUMN bar ADD GENERATED ALWAYS AS IDENTITY (START WITH 10)", - expect: "CREATE TABLE bar (\n bar INTEGER GENERATED ALWAYS AS IDENTITY (START WITH 10),\n id INT PRIMARY KEY\n);", - }, - ], - |ast_a, ast_b| ast_a.migrate(&ast_b), - ); - } - - #[test] - fn apply_create_index() { - run_test_cases( - vec![ - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE UNIQUE INDEX title_idx ON films (title);", - sql_b: "CREATE INDEX code_idx ON films (code);", - expect: "CREATE UNIQUE INDEX title_idx ON films(title);\n\nCREATE INDEX code_idx ON films(code);", - }, - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE UNIQUE INDEX title_idx ON films (title);", - sql_b: "DROP INDEX title_idx;", - expect: "", - }, - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE UNIQUE INDEX title_idx ON films (title);", - sql_b: "DROP INDEX title_idx;CREATE INDEX code_idx ON films (code);", - expect: "CREATE INDEX code_idx ON films(code);", - }, - ], - |ast_a, ast_b| ast_a.migrate(&ast_b), - ); - } - - #[test] - fn apply_alter_create_type() { - run_test_cases( - vec![TestCase { - dialect: Dialect::Generic, + alter_table_alter_column_a { + sql_a: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)", + sql_b: "ALTER TABLE bar ALTER COLUMN bar SET NOT NULL", + expect: "CREATE TABLE bar (bar TEXT NOT NULL, id INT PRIMARY KEY);", + }, + + alter_table_alter_column_b { + sql_a: "CREATE TABLE bar (bar TEXT NOT NULL, id INT PRIMARY KEY)", + sql_b: "ALTER TABLE bar ALTER COLUMN bar DROP NOT NULL", + expect: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY);", + }, + + alter_table_alter_column_c { + sql_a: "CREATE TABLE bar (bar TEXT NOT NULL DEFAULT 'foo', id INT PRIMARY KEY)", + sql_b: "ALTER TABLE bar ALTER COLUMN bar DROP DEFAULT", + expect: "CREATE TABLE bar (bar TEXT NOT NULL, id INT PRIMARY KEY);", + }, + + alter_table_alter_column_d { + sql_a: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)", + sql_b: "ALTER TABLE bar ALTER COLUMN bar SET DATA TYPE INTEGER", + expect: "CREATE TABLE bar (bar INTEGER, id INT PRIMARY KEY);", + }, + + alter_table_alter_column_f { + sql_a: "CREATE TABLE bar (bar INTEGER, id INT PRIMARY KEY)", + sql_b: "ALTER TABLE bar ALTER COLUMN bar ADD GENERATED BY DEFAULT AS IDENTITY", + expect: "CREATE TABLE bar (\n bar INTEGER GENERATED BY DEFAULT AS IDENTITY,\n id INT PRIMARY KEY\n);", + }, + + alter_table_alter_column_g { + sql_a: "CREATE TABLE bar (bar INTEGER, id INT PRIMARY KEY)", + sql_b: "ALTER TABLE bar ALTER COLUMN bar ADD GENERATED ALWAYS AS IDENTITY (START WITH 10)", + expect: "CREATE TABLE bar (\n bar INTEGER GENERATED ALWAYS AS IDENTITY (START WITH 10),\n id INT PRIMARY KEY\n);", + }, + + create_index_a { + sql_a: "CREATE UNIQUE INDEX title_idx ON films (title);", + sql_b: "CREATE INDEX code_idx ON films (code);", + expect: "CREATE UNIQUE INDEX title_idx ON films(title);\n\nCREATE INDEX code_idx ON films(code);", + }, + + create_index_b { + sql_a: "CREATE UNIQUE INDEX title_idx ON films (title);", + sql_b: "DROP INDEX title_idx;", + expect: "", + }, + + create_index_c { + sql_a: "CREATE UNIQUE INDEX title_idx ON films (title);", + sql_b: "DROP INDEX title_idx;CREATE INDEX code_idx ON films (code);", + expect: "CREATE INDEX code_idx ON films(code);", + }, + + alter_create_type_a { sql_a: "CREATE TYPE bug_status AS ENUM ('open', 'closed');", sql_b: "CREATE TYPE compfoo AS (f1 int, f2 text);", expect: "CREATE TYPE bug_status AS ENUM ('open', 'closed');\n\nCREATE TYPE compfoo AS (f1 INT, f2 TEXT);", - }], - |ast_a, ast_b| ast_a.migrate(&ast_b), - ); - } + }, - #[test] - fn apply_alter_type_rename() { - run_test_cases( - vec![TestCase { - dialect: Dialect::Generic, + alter_type_rename_a { sql_a: "CREATE TYPE bug_status AS ENUM ('open', 'closed');", sql_b: "ALTER TYPE bug_status RENAME TO issue_status", expect: "CREATE TYPE issue_status AS ENUM ('open', 'closed');", - }], - |ast_a, ast_b| ast_a.migrate(&ast_b), - ); - } + }, - #[test] - fn apply_alter_type_add_value() { - run_test_cases( - vec![ - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE TYPE bug_status AS ENUM ('open');", - sql_b: "ALTER TYPE bug_status ADD VALUE 'new' BEFORE 'open';", - expect: "CREATE TYPE bug_status AS ENUM ('new', 'open');", - }, - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE TYPE bug_status AS ENUM ('open');", - sql_b: "ALTER TYPE bug_status ADD VALUE 'closed' AFTER 'open';", - expect: "CREATE TYPE bug_status AS ENUM ('open', 'closed');", - }, - TestCase { - dialect: Dialect::Generic, - sql_a: "CREATE TYPE bug_status AS ENUM ('open');", - sql_b: "ALTER TYPE bug_status ADD VALUE 'closed';", - expect: "CREATE TYPE bug_status AS ENUM ('open', 'closed');", - }, - ], - |ast_a, ast_b| ast_a.migrate(&ast_b), - ); - } + alter_type_add_value_a { + sql_a: "CREATE TYPE bug_status AS ENUM ('open');", + sql_b: "ALTER TYPE bug_status ADD VALUE 'new' BEFORE 'open';", + expect: "CREATE TYPE bug_status AS ENUM ('new', 'open');", + }, + + alter_type_add_value_b { + sql_a: "CREATE TYPE bug_status AS ENUM ('open');", + sql_b: "ALTER TYPE bug_status ADD VALUE 'closed' AFTER 'open';", + expect: "CREATE TYPE bug_status AS ENUM ('open', 'closed');", + }, + + alter_type_add_value_c { + sql_a: "CREATE TYPE bug_status AS ENUM ('open');", + sql_b: "ALTER TYPE bug_status ADD VALUE 'closed';", + expect: "CREATE TYPE bug_status AS ENUM ('open', 'closed');", + }, - #[test] - fn apply_alter_type_rename_value() { - run_test_cases( - vec![TestCase { - dialect: Dialect::Generic, + alter_type_rename_value_a { sql_a: "CREATE TYPE bug_status AS ENUM ('new', 'closed');", sql_b: "ALTER TYPE bug_status RENAME VALUE 'new' TO 'open';", expect: "CREATE TYPE bug_status AS ENUM ('open', 'closed');", - }], - |ast_a, ast_b| ast_a.migrate(&ast_b), - ); - } + }, - #[test] - fn apply_create_extension() { - run_test_cases( - vec![TestCase { - dialect: Dialect::Generic, + create_extension_a { sql_a: "CREATE EXTENSION hstore;", sql_b: "CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";", expect: "CREATE EXTENSION hstore;\n\nCREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";", - }], - |ast_a, ast_b| ast_a.migrate(&ast_b), + }, + + => |ast_a, ast_b| { + ast_a.migrate(&ast_b) + } ); - } - #[test] - fn apply_create_domain() { - run_test_cases( - vec![TestCase { - dialect: Dialect::PostgreSql, + test_case!( + @dialect(Dialect::PostgreSql) + + alter_table_alter_column_e { + sql_a: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)", + sql_b: "ALTER TABLE bar ALTER COLUMN bar SET DATA TYPE timestamp with time zone\n USING timestamp with time zone 'epoch' + foo_timestamp * interval '1 second'", + expect: "CREATE TABLE bar (bar TIMESTAMP WITH TIME ZONE, id INT PRIMARY KEY);", + }, + + create_domain_a { sql_a: "CREATE DOMAIN positive_int AS INTEGER CHECK (VALUE > 0);", sql_b: "CREATE DOMAIN email AS VARCHAR(255) CHECK (VALUE ~ '^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$');", expect: "CREATE DOMAIN positive_int AS INTEGER CHECK (VALUE > 0);\n\nCREATE DOMAIN email AS VARCHAR(255) CHECK (\n VALUE ~ '^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$'\n);", - }], - |ast_a, ast_b| ast_a.migrate(&ast_b), + }, + + => |ast_a, ast_b| { + ast_a.migrate(&ast_b) + } + ); + + test_case!( + @dialect(Dialect::Snowflake) + + alter_table_drop_columns { + sql_a: "CREATE TABLE bar (foo TEXT, bar TEXT, id INT PRIMARY KEY)", + sql_b: "ALTER TABLE bar DROP COLUMN foo, bar", + expect: "CREATE TABLE bar (id INT PRIMARY KEY);", + }, + + => |ast_a, ast_b| { + ast_a.migrate(&ast_b) + } ); } } diff --git a/src/name_gen.rs b/src/name_gen.rs index 178ea36..3b13b17 100644 --- a/src/name_gen.rs +++ b/src/name_gen.rs @@ -126,65 +126,79 @@ mod tests { assert_eq!(actual, Some(tc.name.to_owned()), "{tc:?}"); } - fn run_test_cases(test_cases: Vec) { - test_cases.iter().for_each(run_test_case); - } + macro_rules! test_case { + ( + $( + $test_name:ident { + $( $field:ident : $value:expr ),+ $(,)? + } + ),* $(,)? + ) => { + $( + #[test] + fn $test_name() { + let test_case = TestCase { + $( $field : $value ),+ + }; - #[test] - fn test_generate_name() { - run_test_cases(vec![ - TestCase { - sql: "CREATE TABLE foo(bar TEXT);", - name: "create_foo", - }, - TestCase { - sql: "CREATE TABLE foo(bar TEXT); CREATE TABLE bar(foo TEXT);", - name: "create_foo__create_bar", - }, - TestCase { - sql: "CREATE TABLE foo(bar TEXT); CREATE TABLE bar(foo TEXT); CREATE TABLE baz(id INT); CREATE TABLE some_really_long_name(id INT);", - name: "create_foo__create_bar__create_baz__etc", - }, - TestCase { - sql: "ALTER TABLE foo DROP COLUMN bar;", - name: "alter_foo_drop_bar", - }, - TestCase { - sql: "ALTER TABLE foo ADD COLUMN bar TEXT;", - name: "alter_foo_add_bar", - }, - TestCase { - sql: "ALTER TABLE foo ALTER COLUMN bar SET DATA TYPE INT;", - name: "alter_foo_alter_bar", - }, - TestCase { - sql: "ALTER TABLE foo RENAME bar TO id;", - name: "alter_foo_rename_bar_to_id", - }, - TestCase { - sql: "ALTER TABLE foo RENAME TO bar;", - name: "rename_foo_to_bar", - }, - TestCase { - sql: "DROP TABLE foo;", - name: "drop_foo", - }, - TestCase { - sql: "CREATE TYPE status AS ENUM('one', 'two', 'three');", - name: "create_type_status", - }, - TestCase { - sql: "DROP TYPE status;", - name: "drop_type_status", - }, - TestCase { - sql: "CREATE UNIQUE INDEX title_idx ON films (title);", - name: "create_films_title_idx", - }, - TestCase { - sql: "DROP INDEX title_idx", - name: "drop_index_title_idx", - }, - ]); + run_test_case(&test_case); + } + )* + }; } + + test_case!( + create_table { + sql: "CREATE TABLE foo(bar TEXT);", + name: "create_foo", + }, + create_two_tables { + sql: "CREATE TABLE foo(bar TEXT); CREATE TABLE bar(foo TEXT);", + name: "create_foo__create_bar", + }, + create_four_tables { + sql: "CREATE TABLE foo(bar TEXT); CREATE TABLE bar(foo TEXT); CREATE TABLE baz(id INT); CREATE TABLE some_really_long_name(id INT);", + name: "create_foo__create_bar__create_baz__etc", + }, + drop_column { + sql: "ALTER TABLE foo DROP COLUMN bar;", + name: "alter_foo_drop_bar", + }, + add_column { + sql: "ALTER TABLE foo ADD COLUMN bar TEXT;", + name: "alter_foo_add_bar", + }, + alter_column { + sql: "ALTER TABLE foo ALTER COLUMN bar SET DATA TYPE INT;", + name: "alter_foo_alter_bar", + }, + rename_column { + sql: "ALTER TABLE foo RENAME bar TO id;", + name: "alter_foo_rename_bar_to_id", + }, + rename_table { + sql: "ALTER TABLE foo RENAME TO bar;", + name: "rename_foo_to_bar", + }, + drop_table { + sql: "DROP TABLE foo;", + name: "drop_foo", + }, + create_enum_type { + sql: "CREATE TYPE status AS ENUM('one', 'two', 'three');", + name: "create_type_status", + }, + drop_type { + sql: "DROP TYPE status;", + name: "drop_type_status", + }, + create_index { + sql: "CREATE UNIQUE INDEX title_idx ON films (title);", + name: "create_films_title_idx", + }, + drop_index { + sql: "DROP INDEX title_idx", + name: "drop_index_title_idx", + }, + ); } From ddc53dcbb8fbed5e9f2b29c70553019a818cd88d Mon Sep 17 00:00:00 2001 From: Jesse Stuart Date: Mon, 13 Apr 2026 10:39:13 -0400 Subject: [PATCH 4/6] revert!: drop support for most sql dialects we're keeping Generic, PostgreSQL and SQLite, and can add back others as needed closes #9 --- src/lib.rs | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 0e83e0f..5dc8969 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,37 +25,17 @@ pub struct ParseError(#[from] parser::ParserError); #[cfg_attr(feature = "clap", derive(clap::ValueEnum), clap(rename_all = "lower"))] #[non_exhaustive] pub enum Dialect { - Ansi, - BigQuery, - ClickHouse, - Databricks, - DuckDb, #[default] Generic, - Hive, - MsSql, - MySql, PostgreSql, - RedshiftSql, - Snowflake, SQLite, } impl Dialect { fn to_sqlparser_dialect(self) -> Box { match self { - Self::Ansi => Box::new(dialect::AnsiDialect {}), - Self::BigQuery => Box::new(dialect::BigQueryDialect {}), - Self::ClickHouse => Box::new(dialect::ClickHouseDialect {}), - Self::Databricks => Box::new(dialect::DatabricksDialect {}), - Self::DuckDb => Box::new(dialect::DuckDbDialect {}), Self::Generic => Box::new(dialect::GenericDialect {}), - Self::Hive => Box::new(dialect::HiveDialect {}), - Self::MsSql => Box::new(dialect::MsSqlDialect {}), - Self::MySql => Box::new(dialect::MySqlDialect {}), Self::PostgreSql => Box::new(dialect::PostgreSqlDialect {}), - Self::RedshiftSql => Box::new(dialect::RedshiftSqlDialect {}), - Self::Snowflake => Box::new(dialect::SnowflakeDialect {}), Self::SQLite => Box::new(dialect::SQLiteDialect {}), } } @@ -494,19 +474,5 @@ mod tests { ast_a.migrate(&ast_b) } ); - - test_case!( - @dialect(Dialect::Snowflake) - - alter_table_drop_columns { - sql_a: "CREATE TABLE bar (foo TEXT, bar TEXT, id INT PRIMARY KEY)", - sql_b: "ALTER TABLE bar DROP COLUMN foo, bar", - expect: "CREATE TABLE bar (id INT PRIMARY KEY);", - }, - - => |ast_a, ast_b| { - ast_a.migrate(&ast_b) - } - ); } } From 72cc446f8fc1fa1e8a179cdd2d263c17c66beb13 Mon Sep 17 00:00:00 2001 From: Jesse Stuart Date: Sun, 12 Apr 2026 20:16:40 -0400 Subject: [PATCH 5/6] refactor: allow diverging strategies across dialects --- src/ast.rs | 24 +- src/bin/sql-schema.rs | 102 +++++- src/dialect.rs | 14 + src/diff.rs | 563 +++++------------------------ src/diff/generic.rs | 2 + src/diff/generic/statement.rs | 273 ++++++++++++++ src/diff/generic/tree.rs | 279 ++++++++++++++ src/lib.rs | 159 ++++---- src/migration.rs | 445 +++++------------------ src/migration/generic.rs | 2 + src/migration/generic/statement.rs | 329 +++++++++++++++++ src/migration/generic/tree.rs | 142 ++++++++ src/name_gen.rs | 9 +- src/parser.rs | 49 +++ src/sealed.rs | 1 + 15 files changed, 1452 insertions(+), 941 deletions(-) create mode 100644 src/dialect.rs create mode 100644 src/diff/generic.rs create mode 100644 src/diff/generic/statement.rs create mode 100644 src/diff/generic/tree.rs create mode 100644 src/migration/generic.rs create mode 100644 src/migration/generic/statement.rs create mode 100644 src/migration/generic/tree.rs create mode 100644 src/parser.rs create mode 100644 src/sealed.rs diff --git a/src/ast.rs b/src/ast.rs index 0e3218b..083a0ec 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -2,7 +2,25 @@ pub use sqlparser::ast::{ helpers::attached_token::AttachedToken, AlterColumnOperation, AlterTable, AlterTableOperation, AlterType, AlterTypeAddValue, AlterTypeAddValuePosition, AlterTypeOperation, AlterTypeRenameValue, ColumnDef, ColumnOption, ColumnOptionDef, CreateDomain, CreateExtension, - CreateIndex, CreateTable, DropDomain, DropExtension, GeneratedAs, Ident, ObjectName, - ObjectNamePart, ObjectType, ReferentialAction, RenameTableNameKind, Statement, - UserDefinedTypeRepresentation, + CreateIndex, CreateTable, DropDomain, DropExtension, GeneratedAs, ObjectName, ObjectNamePart, + ObjectType, ReferentialAction, RenameTableNameKind, Statement, UserDefinedTypeRepresentation, }; + +/// This is a copy of [`Statement::CreateType`]. +#[derive(Debug, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub struct CreateType { + /// Type name to create. + pub name: ObjectName, + /// Optional type representation details. + pub representation: Option, +} + +impl From for Statement { + fn from(value: CreateType) -> Self { + Statement::CreateType { + name: value.name, + representation: value.representation, + } + } +} diff --git a/src/bin/sql-schema.rs b/src/bin/sql-schema.rs index 7e67243..5bd6c04 100644 --- a/src/bin/sql-schema.rs +++ b/src/bin/sql-schema.rs @@ -1,4 +1,5 @@ use std::{ + fmt, fs::{self, File, OpenOptions}, io::{self, Write}, process::{self}, @@ -12,7 +13,7 @@ use clap::{Parser, Subcommand}; use sql_schema::{ name_gen, path_template::{PathTemplate, TemplateData, UpDown}, - Dialect, SyntaxTree, + SyntaxTree, TreeDiffer, TreeMigrator, }; #[derive(Parser, Debug)] @@ -46,6 +47,30 @@ struct SchemaCommand { dialect: Dialect, } +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default, clap::ValueEnum)] +#[clap(rename_all = "lower")] +#[non_exhaustive] +pub enum Dialect { + #[default] + Generic, + PostgreSql, + SQLite, +} + +impl fmt::Display for Dialect { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // NOTE: this must match how clap::ValueEnum displays variants + write!( + f, + "{}", + format!("{self:?}") + .to_ascii_lowercase() + .split('-') + .collect::() + ) + } +} + #[derive(Parser, Debug)] struct MigrationCommand { /// path to schema file @@ -102,15 +127,44 @@ fn main() { } } +macro_rules! match_dialect { + ( $dialect:expr, $expr:expr ) => { + match $dialect { + Dialect::Generic => { + let dialect = sql_schema::dialect::Generic::default(); + $expr(dialect) + } + Dialect::PostgreSql => { + let dialect = sql_schema::dialect::PostgreSQL::default(); + $expr(dialect) + } + Dialect::SQLite => { + let dialect = sql_schema::dialect::SQLite::default(); + $expr(dialect) + } + } + }; +} + /// create or update schema file from migrations fn run_schema(command: SchemaCommand) -> anyhow::Result<()> { ensure_schema_file(&command.schema_path)?; ensure_migration_dir(&command.migrations_dir)?; - let (migrations, _) = parse_migrations(command.dialect, &command.migrations_dir)?; - let schema = parse_sql_file(command.dialect, &command.schema_path)?; + match_dialect!(&command.dialect, |dialect| run_schema_inner( + dialect, command + )) +} + +fn run_schema_inner(dialect: D, command: SchemaCommand) -> anyhow::Result<()> +where + D: TreeDiffer + TreeMigrator + sql_schema::Parse, +{ + let (migrations, _) = parse_migrations(dialect.clone(), &command.migrations_dir)?; + let schema = parse_sql_file(dialect, &command.schema_path)?; + let diff = schema.diff(&migrations)?.unwrap_or_else(SyntaxTree::empty); - let schema = schema.migrate(&diff)?.unwrap_or_else(SyntaxTree::empty); + let schema = schema.migrate(&diff)?; eprintln!("writing {}", command.schema_path); OpenOptions::new() .write(true) @@ -126,9 +180,18 @@ fn run_migration(command: MigrationCommand) -> anyhow::Result<()> { ensure_schema_file(&command.schema_path)?; ensure_migration_dir(&command.migrations_dir)?; - let (migrations, opts) = parse_migrations(command.dialect, &command.migrations_dir)?; + match_dialect!(&command.dialect, |dialect| run_migration_inner( + dialect, command + )) +} + +fn run_migration_inner(dialect: D, command: MigrationCommand) -> anyhow::Result<()> +where + D: TreeDiffer + TreeMigrator + sql_schema::Parse, +{ + let (migrations, opts) = parse_migrations(dialect.clone(), &command.migrations_dir)?; let opts = opts.reconcile(&command); - let schema = parse_sql_file(command.dialect, &command.schema_path)?; + let schema = parse_sql_file(dialect, &command.schema_path)?; match migrations.diff(&schema)? { Some(up_migration) => { let name = if opts.num_migrations == 0 { @@ -191,7 +254,7 @@ fn run_migration(command: MigrationCommand) -> anyhow::Result<()> { } } -fn write_migration(migration: SyntaxTree, path: &Utf8Path) -> anyhow::Result<()> { +fn write_migration(migration: SyntaxTree, path: &Utf8Path) -> anyhow::Result<()> { eprintln!("writing {path}"); if let Some(parent) = path.parent() { eprintln!("creating {parent}"); @@ -228,20 +291,23 @@ fn ensure_migration_dir(dir: &Utf8Path) -> anyhow::Result<()> { Ok(()) } -fn parse_sql_file(dialect: Dialect, path: &Utf8Path) -> anyhow::Result { +fn parse_sql_file(dialect: Dialect, path: &Utf8Path) -> anyhow::Result> +where + Dialect: sql_schema::Parse, +{ let data = fs::read_to_string(path)?; - SyntaxTree::builder() - .dialect(dialect) - .sql(data.as_str()) - .build() - .context(format!("path: {path}")) + let data = data.as_str(); + SyntaxTree::parse(dialect, data).context(format!("path: {path}")) } /// builds a [SyntaxTree] by applying each migration in order -fn parse_migrations( +fn parse_migrations( dialect: Dialect, dir: &Utf8Path, -) -> anyhow::Result<(SyntaxTree, MigrationOptions)> { +) -> anyhow::Result<(SyntaxTree, MigrationOptions)> +where + Dialect: TreeDiffer + TreeMigrator + sql_schema::Parse, +{ fn process_dir_entry( entry: io::Result, ) -> anyhow::Result>> { @@ -308,10 +374,8 @@ fn parse_migrations( .iter() .try_fold(SyntaxTree::empty(), |schema, path| -> anyhow::Result<_> { eprintln!("parsing {path}"); - let migration = parse_sql_file(dialect, path)?; - let schema = schema - .migrate(&migration)? - .unwrap_or_else(SyntaxTree::empty); + let migration = parse_sql_file(dialect.clone(), path)?; + let schema = schema.migrate(&migration)?; Ok(schema) })?; Ok((tree, opts)) diff --git a/src/dialect.rs b/src/dialect.rs new file mode 100644 index 0000000..1be8a11 --- /dev/null +++ b/src/dialect.rs @@ -0,0 +1,14 @@ +use crate::sealed::Sealed; + +#[derive(Debug, Default, Clone)] +pub struct Generic; + +#[derive(Debug, Default, Clone)] +pub struct PostgreSQL; + +#[derive(Debug, Default, Clone)] +pub struct SQLite; + +impl Sealed for Generic {} +impl Sealed for PostgreSQL {} +impl Sealed for SQLite {} diff --git a/src/diff.rs b/src/diff.rs index a4ebb84..92b0258 100644 --- a/src/diff.rs +++ b/src/diff.rs @@ -1,15 +1,16 @@ -use std::{cmp::Ordering, collections::HashSet, fmt}; +use std::fmt; use bon::bon; use thiserror::Error; -use crate::ast::{ - AlterTable, AlterTableOperation, AlterType, AlterTypeAddValue, AlterTypeAddValuePosition, - AlterTypeOperation, AttachedToken, CreateDomain, CreateExtension, CreateIndex, CreateTable, - DropDomain, DropExtension, Ident, ObjectName, ObjectType, Statement, - UserDefinedTypeRepresentation, +use crate::{ + ast::{CreateDomain, CreateExtension, CreateIndex, CreateTable, CreateType, Statement}, + dialect::{Generic, PostgreSQL, SQLite}, + sealed::Sealed, }; +pub mod generic; + #[derive(Error, Debug)] pub struct DiffError { kind: DiffErrorKind, @@ -65,501 +66,103 @@ pub enum DiffErrorKind { pub type Result = std::result::Result; -pub(crate) trait Diff: Sized { - type Diff; - - fn diff(&self, other: &Self) -> Result; -} - -impl Diff for Vec { - type Diff = Option>; - - fn diff(&self, other: &Self) -> Result { - let res = self - .iter() - .filter_map(|sa| { - match sa { - // CreateTable: compare against another CreateTable with the same name - // TODO: handle renames (e.g. use comments to tag a previous name for a table in a schema) - Statement::CreateTable(a) => find_and_compare_create_table(sa, a, other), - Statement::CreateIndex(a) => find_and_compare_create_index(sa, a, other), - Statement::CreateType { name, .. } => { - find_and_compare_create_type(sa, name, other) - } - Statement::CreateExtension(CreateExtension { - name, - if_not_exists, - cascade, - .. - }) => { - find_and_compare_create_extension(sa, name, *if_not_exists, *cascade, other) - } - Statement::CreateDomain(a) => find_and_compare_create_domain(sa, a, other), - _ => Err(DiffError::builder() - .kind(DiffErrorKind::NotImplemented) - .statement_a(sa.clone()) - .build()), - } - .transpose() - }) - // find resources that are in `other` but not in `self` - .chain(other.iter().filter_map(|sb| { - match sb { - Statement::CreateTable(b) => Ok(self.iter().find(|sa| match sa { - Statement::CreateTable(a) => a.name == b.name, - _ => false, - })), - Statement::CreateIndex(b) => Ok(self.iter().find(|sa| match sa { - Statement::CreateIndex(a) => a.name == b.name, - _ => false, - })), - Statement::CreateType { name: b_name, .. } => { - Ok(self.iter().find(|sa| match sa { - Statement::CreateType { name: a_name, .. } => a_name == b_name, - _ => false, - })) - } - Statement::CreateExtension(CreateExtension { name: b_name, .. }) => { - Ok(self.iter().find(|sa| match sa { - Statement::CreateExtension(CreateExtension { - name: a_name, .. - }) => a_name == b_name, - _ => false, - })) - } - Statement::CreateDomain(b) => Ok(self.iter().find(|sa| match sa { - Statement::CreateDomain(a) => a.name == b.name, - _ => false, - })), - _ => Err(DiffError::builder() - .kind(DiffErrorKind::NotImplemented) - .statement_a(sb.clone()) - .build()), - } - .transpose() - // return the statement if it's not in `self` - .map_or_else(|| Some(Ok(vec![sb.clone()])), |_| None) - })) - .collect::, _>>()? - .into_iter() - .flatten() - .collect::>(); - - if res.is_empty() { - Ok(None) - } else { - Ok(Some(res)) - } +pub trait TreeDiffer: StatementDiffer + Sealed { + fn diff_tree(&self, a: &[Statement], b: &[Statement]) -> Result>> { + generic::tree::tree_diff(self, a, b) } -} - -fn find_and_compare( - sa: &Statement, - other: &[Statement], - match_fn: MF, - drop_fn: DF, -) -> Result>> -where - MF: Fn(&&Statement) -> bool, - DF: Fn() -> Result>>, -{ - other.iter().find(match_fn).map_or_else( - // drop the statement if it wasn't found in `other` - drop_fn, - // otherwise diff the two statements - |sb| sa.diff(sb), - ) -} - -fn find_and_compare_create_table( - sa: &Statement, - a: &CreateTable, - other: &[Statement], -) -> Result>> { - find_and_compare( - sa, - other, - |sb| match sb { - Statement::CreateTable(b) => a.name == b.name, - _ => false, - }, - || { - Ok(Some(vec![Statement::Drop { - object_type: crate::ast::ObjectType::Table, - if_exists: a.if_not_exists, - names: vec![a.name.clone()], - cascade: false, - restrict: false, - purge: false, - temporary: false, - table: None, - }])) - }, - ) -} -fn find_and_compare_create_index( - sa: &Statement, - a: &CreateIndex, - other: &[Statement], -) -> Result>> { - find_and_compare( - sa, - other, - |sb| match sb { - Statement::CreateIndex(b) => a.name == b.name, - _ => false, - }, - || { - let name = a.name.clone().ok_or_else(|| { - DiffError::builder() - .kind(DiffErrorKind::DropUnnamedIndex) - .statement_a(sa.clone()) - .build() - })?; - - Ok(Some(vec![Statement::Drop { - object_type: crate::ast::ObjectType::Index, - if_exists: a.if_not_exists, - names: vec![name], - cascade: false, - restrict: false, - purge: false, - temporary: false, - table: None, - }])) - }, - ) -} - -fn find_and_compare_create_type( - sa: &Statement, - a_name: &ObjectName, - other: &[Statement], -) -> Result>> { - find_and_compare( - sa, - other, - |sb| match sb { - Statement::CreateType { name: b_name, .. } => a_name == b_name, - _ => false, - }, - || { - Ok(Some(vec![Statement::Drop { - object_type: crate::ast::ObjectType::Type, - if_exists: false, - names: vec![a_name.clone()], - cascade: false, - restrict: false, - purge: false, - temporary: false, - table: None, - }])) - }, - ) -} + fn find_and_compare_create_table( + &self, + sa: &Statement, + a: &CreateTable, + b: &[Statement], + ) -> Result>> { + generic::tree::find_and_compare_create_table(self, sa, a, b) + } -fn find_and_compare_create_extension( - sa: &Statement, - a_name: &Ident, - if_not_exists: bool, - cascade: bool, - other: &[Statement], -) -> Result>> { - find_and_compare( - sa, - other, - |sb| match sb { - Statement::CreateExtension(CreateExtension { name: b_name, .. }) => a_name == b_name, - _ => false, - }, - || { - Ok(Some(vec![Statement::DropExtension(DropExtension { - names: vec![a_name.clone()], - if_exists: if_not_exists, - cascade_or_restrict: if cascade { - Some(crate::ast::ReferentialAction::Cascade) - } else { - None - }, - })])) - }, - ) -} + fn find_and_compare_create_index( + &self, + sa: &Statement, + a: &CreateIndex, + b: &[Statement], + ) -> Result>> { + generic::tree::find_and_compare_create_index(self, sa, a, b) + } -fn find_and_compare_create_domain( - orig: &Statement, - domain: &CreateDomain, - other: &[Statement], -) -> Result>> { - let res = other - .iter() - .find(|sb| match sb { - Statement::CreateDomain(b) => b.name == domain.name, - _ => false, - }) - .map(|sb| orig.diff(sb)) - .transpose()? - .flatten(); - Ok(res) -} + fn find_and_compare_create_type( + &self, + sa: &Statement, + a: &CreateType, + b: &[Statement], + ) -> Result>> { + generic::tree::find_and_compare_create_type(self, sa, a, b) + } -impl Diff for Statement { - type Diff = Option>; + fn find_and_compare_create_extension( + &self, + sa: &Statement, + a: &CreateExtension, + b: &[Statement], + ) -> Result>> { + generic::tree::find_and_compare_create_extension(self, sa, a, b) + } - fn diff(&self, other: &Self) -> Result { - match self { - Self::CreateTable(a) => match other { - Self::CreateTable(b) => Ok(compare_create_table(a, b)), - _ => Ok(None), - }, - Self::CreateIndex(a) => match other { - Self::CreateIndex(b) => compare_create_index(a, b), - _ => Ok(None), - }, - Self::CreateType { - name: a_name, - representation: a_rep, - } => match other { - Self::CreateType { - name: b_name, - representation: b_rep, - } => compare_create_type(self, a_name, a_rep, other, b_name, b_rep), - _ => Ok(None), - }, - Self::CreateDomain(a) => match other { - Self::CreateDomain(b) => Ok(compare_create_domain(a, b)), - _ => Ok(None), - }, - _ => Err(DiffError::builder() - .kind(DiffErrorKind::NotImplemented) - .statement_a(self.clone()) - .statement_b(other.clone()) - .build()), - } + fn find_and_compare_create_domain( + &self, + sa: &Statement, + a: &CreateDomain, + b: &[Statement], + ) -> Result>> { + generic::tree::find_and_compare_create_domain(self, sa, a, b) } } -fn compare_create_table(a: &CreateTable, b: &CreateTable) -> Option> { - if a == b { - return None; - } +impl TreeDiffer for Generic {} - let a_column_names: HashSet<_> = a.columns.iter().map(|c| c.name.value.clone()).collect(); - let b_column_names: HashSet<_> = b.columns.iter().map(|c| c.name.value.clone()).collect(); +impl TreeDiffer for PostgreSQL {} - let operations: Vec<_> = a - .columns - .iter() - .filter_map(|ac| { - if b_column_names.contains(&ac.name.value) { - None - } else { - // drop column if it only exists in `a` - Some(AlterTableOperation::DropColumn { - column_names: vec![ac.name.clone()], - if_exists: a.if_not_exists, - drop_behavior: None, - has_column_keyword: true, - }) - } - }) - .chain(b.columns.iter().filter_map(|bc| { - if a_column_names.contains(&bc.name.value) { - None - } else { - // add the column if it only exists in `b` - Some(AlterTableOperation::AddColumn { - column_keyword: true, - if_not_exists: a.if_not_exists, - column_def: bc.clone(), - column_position: None, - }) - } - })) - .collect(); +impl TreeDiffer for SQLite {} - if operations.is_empty() { - return None; +pub trait StatementDiffer: fmt::Debug + Default + Clone + Sized + Sealed { + fn diff(&self, sa: &Statement, sb: &Statement) -> Result>> { + generic::statement::diff(self, sa, sb) } - Some(vec![Statement::AlterTable(AlterTable { - table_type: None, - name: a.name.clone(), - if_exists: a.if_not_exists, - only: false, - operations, - location: None, - on_cluster: a.on_cluster.clone(), - end_token: AttachedToken::empty(), - })]) -} - -fn compare_create_index(a: &CreateIndex, b: &CreateIndex) -> Result>> { - if a == b { - return Ok(None); + fn compare_create_table( + &self, + a: &CreateTable, + b: &CreateTable, + ) -> Result>> { + generic::statement::compare_create_table(a, b) } - if a.name.is_none() || b.name.is_none() { - return Err(DiffError::builder() - .kind(DiffErrorKind::CompareUnnamedIndex) - .statement_a(Statement::CreateIndex(a.clone())) - .statement_b(Statement::CreateIndex(b.clone())) - .build()); + fn compare_create_index( + &self, + a: &CreateIndex, + b: &CreateIndex, + ) -> Result>> { + generic::statement::compare_create_index(a, b) } - let name = a.name.clone().unwrap(); - - Ok(Some(vec![ - Statement::Drop { - object_type: ObjectType::Index, - if_exists: a.if_not_exists, - names: vec![name], - cascade: false, - restrict: false, - purge: false, - temporary: false, - table: None, - }, - Statement::CreateIndex(b.clone()), - ])) -} -fn compare_create_type( - a: &Statement, - a_name: &ObjectName, - a_rep: &Option, - b: &Statement, - b_name: &ObjectName, - b_rep: &Option, -) -> Result>> { - if a_name == b_name && a_rep == b_rep { - return Ok(None); + fn compare_create_type( + &self, + a: &CreateType, + b: &CreateType, + ) -> Result>> { + generic::statement::compare_create_type(a, b) } - let operations = match a_rep { - Some(UserDefinedTypeRepresentation::Enum { labels: a_labels }) => match b_rep { - Some(UserDefinedTypeRepresentation::Enum { labels: b_labels }) => { - match a_labels.len().cmp(&b_labels.len()) { - Ordering::Equal => { - let rename_labels: Vec<_> = a_labels - .iter() - .zip(b_labels.iter()) - .filter_map(|(a, b)| { - if a == b { - None - } else { - Some(AlterTypeOperation::RenameValue( - crate::ast::AlterTypeRenameValue { - from: a.clone(), - to: b.clone(), - }, - )) - } - }) - .collect(); - rename_labels - } - Ordering::Less => { - let mut a_labels_iter = a_labels.iter().peekable(); - let mut operations = Vec::new(); - let mut prev = None; - for b in b_labels { - match a_labels_iter.peek() { - Some(a) => { - let a = *a; - if a == b { - prev = Some(a); - a_labels_iter.next(); - continue; - } - - let position = match prev { - Some(a) => AlterTypeAddValuePosition::After(a.clone()), - None => AlterTypeAddValuePosition::Before(a.clone()), - }; - - prev = Some(b); - operations.push(AlterTypeOperation::AddValue( - AlterTypeAddValue { - if_not_exists: false, - value: b.clone(), - position: Some(position), - }, - )); - } - None => { - if a_labels.contains(b) { - continue; - } - // labels occuring after all existing ones get added to the end - operations.push(AlterTypeOperation::AddValue( - AlterTypeAddValue { - if_not_exists: false, - value: b.clone(), - position: None, - }, - )); - } - } - } - operations - } - _ => { - return Err(DiffError::builder() - .kind(DiffErrorKind::RemoveEnumLabel) - .statement_a(a.clone()) - .statement_b(b.clone()) - .build()); - } - } - } - _ => { - // TODO: DROP and CREATE type - return Err(DiffError::builder() - .kind(DiffErrorKind::NotImplemented) - .statement_a(a.clone()) - .statement_b(b.clone()) - .build()); - } - }, - _ => { - // TODO: handle diffing composite attributes for CREATE TYPE - return Err(DiffError::builder() - .kind(DiffErrorKind::NotImplemented) - .statement_a(a.clone()) - .statement_b(b.clone()) - .build()); - } - }; - - if operations.is_empty() { - return Ok(None); + fn compare_create_domain( + &self, + a: &CreateDomain, + b: &CreateDomain, + ) -> Result>> { + generic::statement::compare_create_domain(a, b) } - - Ok(Some( - operations - .into_iter() - .map(|operation| { - Statement::AlterType(AlterType { - name: a_name.clone(), - operation, - }) - }) - .collect(), - )) } -fn compare_create_domain(a: &CreateDomain, b: &CreateDomain) -> Option> { - if a == b { - return None; - } +impl StatementDiffer for Generic {} - Some(vec![ - Statement::DropDomain(DropDomain { - if_exists: true, - name: a.name.clone(), - drop_behavior: None, - }), - Statement::CreateDomain(b.clone()), - ]) -} +impl StatementDiffer for PostgreSQL {} + +impl StatementDiffer for SQLite {} diff --git a/src/diff/generic.rs b/src/diff/generic.rs new file mode 100644 index 0000000..5b31d1b --- /dev/null +++ b/src/diff/generic.rs @@ -0,0 +1,2 @@ +pub mod statement; +pub mod tree; diff --git a/src/diff/generic/statement.rs b/src/diff/generic/statement.rs new file mode 100644 index 0000000..f30186b --- /dev/null +++ b/src/diff/generic/statement.rs @@ -0,0 +1,273 @@ +use std::{cmp::Ordering, collections::HashSet}; + +use crate::{ + ast::{ + AlterTable, AlterTableOperation, AlterType, AlterTypeAddValue, AlterTypeAddValuePosition, + AlterTypeOperation, AlterTypeRenameValue, AttachedToken, CreateDomain, CreateIndex, + CreateTable, CreateType, DropDomain, ObjectType, Statement, UserDefinedTypeRepresentation, + }, + diff::{DiffError, DiffErrorKind, Result, StatementDiffer}, +}; + +pub fn diff( + dialect: &Dialect, + sa: &Statement, + sb: &Statement, +) -> Result>> +where + Dialect: StatementDiffer, +{ + match sa { + Statement::CreateTable(a) => match sb { + Statement::CreateTable(b) => dialect.compare_create_table(a, b), + _ => Ok(None), + }, + Statement::CreateIndex(a) => match sb { + Statement::CreateIndex(b) => dialect.compare_create_index(a, b), + _ => Ok(None), + }, + Statement::CreateType { + name: a_name, + representation: a_rep, + } => match sb { + Statement::CreateType { + name: b_name, + representation: b_rep, + } => dialect.compare_create_type( + &CreateType { + name: a_name.clone(), + representation: a_rep.clone(), + }, + &CreateType { + name: b_name.clone(), + representation: b_rep.clone(), + }, + ), + _ => Ok(None), + }, + Statement::CreateDomain(a) => match sb { + Statement::CreateDomain(b) => dialect.compare_create_domain(a, b), + _ => Ok(None), + }, + _ => Err(DiffError::builder() + .kind(DiffErrorKind::NotImplemented) + .statement_a(sa.clone()) + .statement_b(sb.clone()) + .build()), + } +} + +pub fn compare_create_table(a: &CreateTable, b: &CreateTable) -> Result>> { + if a == b { + return Ok(None); + } + + let a_column_names: HashSet<_> = a.columns.iter().map(|c| c.name.value.clone()).collect(); + let b_column_names: HashSet<_> = b.columns.iter().map(|c| c.name.value.clone()).collect(); + + let operations: Vec<_> = a + .columns + .iter() + .filter_map(|ac| { + if b_column_names.contains(&ac.name.value) { + None + } else { + // drop column if it only exists in `a` + Some(AlterTableOperation::DropColumn { + column_names: vec![ac.name.clone()], + if_exists: a.if_not_exists, + drop_behavior: None, + has_column_keyword: true, + }) + } + }) + .chain(b.columns.iter().filter_map(|bc| { + if a_column_names.contains(&bc.name.value) { + None + } else { + // add the column if it only exists in `b` + Some(AlterTableOperation::AddColumn { + column_keyword: true, + if_not_exists: a.if_not_exists, + column_def: bc.clone(), + column_position: None, + }) + } + })) + .collect(); + + if operations.is_empty() { + return Ok(None); + } + + Ok(Some(vec![Statement::AlterTable(AlterTable { + table_type: None, + name: a.name.clone(), + if_exists: a.if_not_exists, + only: false, + operations, + location: None, + on_cluster: a.on_cluster.clone(), + end_token: AttachedToken::empty(), + })])) +} + +pub fn compare_create_index(a: &CreateIndex, b: &CreateIndex) -> Result>> { + if a == b { + return Ok(None); + } + + if a.name.is_none() || b.name.is_none() { + Err(DiffError::builder() + .kind(DiffErrorKind::CompareUnnamedIndex) + .statement_a(Statement::CreateIndex(a.clone())) + .statement_b(Statement::CreateIndex(b.clone())) + .build())?; + } + let name = a.name.clone().unwrap(); + + Ok(Some(vec![ + Statement::Drop { + object_type: ObjectType::Index, + if_exists: a.if_not_exists, + names: vec![name], + cascade: false, + restrict: false, + purge: false, + temporary: false, + table: None, + }, + Statement::CreateIndex(b.clone()), + ])) +} + +pub fn compare_create_type(a: &CreateType, b: &CreateType) -> Result>> { + if a == b { + return Ok(None); + } + + let operations = match &a.representation { + Some(UserDefinedTypeRepresentation::Enum { labels: a_labels }) => match &b.representation { + Some(UserDefinedTypeRepresentation::Enum { labels: b_labels }) => { + match a_labels.len().cmp(&b_labels.len()) { + Ordering::Equal => { + let rename_labels: Vec<_> = a_labels + .iter() + .zip(b_labels.iter()) + .filter_map(|(a, b)| { + if a == b { + None + } else { + Some(AlterTypeOperation::RenameValue(AlterTypeRenameValue { + from: a.clone(), + to: b.clone(), + })) + } + }) + .collect(); + rename_labels + } + Ordering::Less => { + let mut a_labels_iter = a_labels.iter().peekable(); + let mut operations = Vec::new(); + let mut prev = None; + for b in b_labels { + match a_labels_iter.peek() { + Some(a) => { + let a = *a; + if a == b { + prev = Some(a); + a_labels_iter.next(); + continue; + } + + let position = match prev { + Some(a) => AlterTypeAddValuePosition::After(a.clone()), + None => AlterTypeAddValuePosition::Before(a.clone()), + }; + + prev = Some(b); + operations.push(AlterTypeOperation::AddValue( + AlterTypeAddValue { + if_not_exists: false, + value: b.clone(), + position: Some(position), + }, + )); + } + None => { + if a_labels.contains(b) { + continue; + } + // labels occuring after all existing ones get added to the end + operations.push(AlterTypeOperation::AddValue( + AlterTypeAddValue { + if_not_exists: false, + value: b.clone(), + position: None, + }, + )); + } + } + } + operations + } + _ => { + return Err(DiffError::builder() + .kind(DiffErrorKind::RemoveEnumLabel) + .statement_a(a.clone()) + .statement_b(b.clone()) + .build())?; + } + } + } + _ => { + // TODO: DROP and CREATE type + return Err(DiffError::builder() + .kind(DiffErrorKind::NotImplemented) + .statement_a(a.clone()) + .statement_b(b.clone()) + .build())?; + } + }, + _ => { + // TODO: handle diffing composite attributes for CREATE TYPE + return Err(DiffError::builder() + .kind(DiffErrorKind::NotImplemented) + .statement_a(a.clone()) + .statement_b(b.clone()) + .build())?; + } + }; + + if operations.is_empty() { + return Ok(None); + } + + Ok(Some( + operations + .into_iter() + .map(|operation| { + Statement::AlterType(AlterType { + name: a.name.clone(), + operation, + }) + }) + .collect(), + )) +} + +pub fn compare_create_domain(a: &CreateDomain, b: &CreateDomain) -> Result>> { + if a == b { + return Ok(None); + } + + Ok(Some(vec![ + Statement::DropDomain(DropDomain { + if_exists: true, + name: a.name.clone(), + drop_behavior: None, + }), + Statement::CreateDomain(b.clone()), + ])) +} diff --git a/src/diff/generic/tree.rs b/src/diff/generic/tree.rs new file mode 100644 index 0000000..1ed163a --- /dev/null +++ b/src/diff/generic/tree.rs @@ -0,0 +1,279 @@ +use crate::{ + ast::{ + CreateDomain, CreateExtension, CreateIndex, CreateTable, CreateType, DropDomain, + DropExtension, Statement, + }, + diff::{DiffError, DiffErrorKind, Result, StatementDiffer, TreeDiffer}, +}; + +pub fn tree_diff( + dialect: &Dialect, + a: &[Statement], + b: &[Statement], +) -> Result>> +where + Dialect: TreeDiffer, +{ + let res = a + .iter() + .filter_map(|sa| { + match sa { + // CreateTable: compare against another CreateTable with the same name + // TODO: handle renames (e.g. use comments to tag a previous name for a table in a schema) + Statement::CreateTable(a) => dialect.find_and_compare_create_table(sa, a, b), + Statement::CreateIndex(a) => dialect.find_and_compare_create_index(sa, a, b), + Statement::CreateType { + name, + representation, + } => dialect.find_and_compare_create_type( + sa, + &CreateType { + name: name.clone(), + representation: representation.clone(), + }, + b, + ), + Statement::CreateExtension(sb) => { + dialect.find_and_compare_create_extension(sa, sb, b) + } + Statement::CreateDomain(a) => dialect.find_and_compare_create_domain(sa, a, b), + _ => Err(DiffError::builder() + .kind(DiffErrorKind::NotImplemented) + .statement_a(sa.clone()) + .build()), + } + .transpose() + }) + // find resources that are in `other` but not in `a` + .chain(b.iter().filter_map(|sb| { + match sb { + Statement::CreateTable(b) => Ok(a.iter().find(|sa| match sa { + Statement::CreateTable(a) => a.name == b.name, + _ => false, + })), + Statement::CreateIndex(b) => Ok(a.iter().find(|sa| match sa { + Statement::CreateIndex(a) => a.name == b.name, + _ => false, + })), + Statement::CreateType { name: b_name, .. } => Ok(a.iter().find(|sa| match sa { + Statement::CreateType { name: a_name, .. } => a_name == b_name, + _ => false, + })), + Statement::CreateExtension(CreateExtension { name: b_name, .. }) => { + Ok(a.iter().find(|sa| match sa { + Statement::CreateExtension(CreateExtension { name: a_name, .. }) => { + a_name == b_name + } + _ => false, + })) + } + Statement::CreateDomain(b) => Ok(a.iter().find(|sa| match sa { + Statement::CreateDomain(a) => a.name == b.name, + _ => false, + })), + _ => Err(DiffError::builder() + .kind(DiffErrorKind::NotImplemented) + .statement_a(sb.clone()) + .build()), + } + .transpose() + // return the statement if it's not in `self` + .map_or_else(|| Some(Ok(vec![sb.clone()])), |_| None) + })) + .collect::, _>>()? + .into_iter() + .flatten() + .collect::>(); + + if res.is_empty() { + Ok(None) + } else { + Ok(Some(res)) + } +} + +fn find_and_compare( + dialect: &Dialect, + sa: &Statement, + b: &[Statement], + match_fn: MF, + drop_fn: DF, +) -> Result>> +where + Dialect: StatementDiffer, + MF: Fn(&&Statement) -> bool, + DF: Fn() -> Result>>, +{ + b.iter().find(match_fn).map_or_else( + // drop the statement if it wasn't found in `other` + drop_fn, + // otherwise diff the two statements + |sb| StatementDiffer::diff(dialect, sa, sb), + ) +} + +pub fn find_and_compare_create_table( + dialect: &Dialect, + sa: &Statement, + a: &CreateTable, + b: &[Statement], +) -> Result>> +where + Dialect: StatementDiffer, +{ + find_and_compare( + dialect, + sa, + b, + |sb| match sb { + Statement::CreateTable(b) => a.name == b.name, + _ => false, + }, + || { + Ok(Some(vec![Statement::Drop { + object_type: crate::ast::ObjectType::Table, + if_exists: a.if_not_exists, + names: vec![a.name.clone()], + cascade: false, + restrict: false, + purge: false, + temporary: false, + table: None, + }])) + }, + ) +} + +pub fn find_and_compare_create_index( + dialect: &Dialect, + sa: &Statement, + a: &CreateIndex, + b: &[Statement], +) -> Result>> +where + Dialect: StatementDiffer, +{ + find_and_compare( + dialect, + sa, + b, + |sb| match sb { + Statement::CreateIndex(b) => a.name == b.name, + _ => false, + }, + || { + let name = a.name.clone().ok_or_else(|| { + DiffError::builder() + .kind(DiffErrorKind::DropUnnamedIndex) + .statement_a(sa.clone()) + .build() + })?; + + Ok(Some(vec![Statement::Drop { + object_type: crate::ast::ObjectType::Index, + if_exists: a.if_not_exists, + names: vec![name], + cascade: false, + restrict: false, + purge: false, + temporary: false, + table: None, + }])) + }, + ) +} + +pub fn find_and_compare_create_type( + dialect: &Dialect, + sa: &Statement, + a: &CreateType, + b: &[Statement], +) -> Result>> +where + Dialect: StatementDiffer, +{ + let a_name = &a.name; + find_and_compare( + dialect, + sa, + b, + |sb| match sb { + Statement::CreateType { name: b_name, .. } => a_name == b_name, + _ => false, + }, + || { + Ok(Some(vec![Statement::Drop { + object_type: crate::ast::ObjectType::Type, + if_exists: false, + names: vec![a_name.clone()], + cascade: false, + restrict: false, + purge: false, + temporary: false, + table: None, + }])) + }, + ) +} + +pub fn find_and_compare_create_extension( + dialect: &Dialect, + sa: &Statement, + a: &CreateExtension, + b: &[Statement], +) -> Result>> +where + Dialect: StatementDiffer, +{ + let a_name = &a.name; + let if_not_exists = a.if_not_exists; + let cascade = a.cascade; + + find_and_compare( + dialect, + sa, + b, + |sb| match sb { + Statement::CreateExtension(CreateExtension { name: b_name, .. }) => a_name == b_name, + _ => false, + }, + || { + Ok(Some(vec![Statement::DropExtension(DropExtension { + names: vec![a_name.clone()], + if_exists: if_not_exists, + cascade_or_restrict: if cascade { + Some(crate::ast::ReferentialAction::Cascade) + } else { + None + }, + })])) + }, + ) +} + +pub fn find_and_compare_create_domain( + dialect: &Dialect, + sa: &Statement, + a: &CreateDomain, + b: &[Statement], +) -> Result>> +where + Dialect: StatementDiffer, +{ + find_and_compare( + dialect, + sa, + b, + |sb| match sb { + Statement::CreateDomain(b) => b.name == a.name, + _ => false, + }, + || { + Ok(Some(vec![Statement::DropDomain(DropDomain { + name: a.name.clone(), + if_exists: false, + drop_behavior: None, + })])) + }, + ) +} diff --git a/src/lib.rs b/src/lib.rs index 5dc8969..4f06f26 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,93 +1,80 @@ use std::fmt; -use bon::bon; -use sqlparser::{ - dialect::{self}, - parser::{self, Parser}, -}; -use thiserror::Error; +use self::ast::Statement; -use ast::Statement; -use diff::Diff; -use migration::Migrate; +pub use self::{ + diff::TreeDiffer, + migration::TreeMigrator, + parser::{Parse, ParseError}, +}; mod ast; +pub mod dialect; mod diff; mod migration; pub mod name_gen; +mod parser; pub mod path_template; +mod sealed; -#[derive(Error, Debug)] -#[error("Oops, we couldn't parse that!")] -pub struct ParseError(#[from] parser::ParserError); - -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] -#[cfg_attr(feature = "clap", derive(clap::ValueEnum), clap(rename_all = "lower"))] -#[non_exhaustive] -pub enum Dialect { - #[default] - Generic, - PostgreSql, - SQLite, +#[derive(Debug, Clone)] +pub struct SyntaxTree { + dialect: Dialect, + pub(crate) tree: Vec, } -impl Dialect { - fn to_sqlparser_dialect(self) -> Box { - match self { - Self::Generic => Box::new(dialect::GenericDialect {}), - Self::PostgreSql => Box::new(dialect::PostgreSqlDialect {}), - Self::SQLite => Box::new(dialect::SQLiteDialect {}), +impl SyntaxTree { + pub fn empty() -> Self { + Self { + dialect: Default::default(), + tree: Vec::with_capacity(0), } } } -impl fmt::Display for Dialect { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - // NOTE: this must match how clap::ValueEnum displays variants - write!( - f, - "{}", - format!("{self:?}") - .to_ascii_lowercase() - .split('-') - .collect::() - ) - } -} - -#[derive(Debug, Clone)] -pub struct SyntaxTree(pub(crate) Vec); - -#[bon] -impl SyntaxTree { - #[builder] - pub fn new<'a>(dialect: Option, sql: impl Into<&'a str>) -> Result { - let dialect = dialect.unwrap_or_default().to_sqlparser_dialect(); - let ast = Parser::parse_sql(dialect.as_ref(), sql.into())?; - Ok(Self(ast)) - } - - pub fn empty() -> Self { - Self(vec![]) +impl SyntaxTree +where + Dialect: Parse, +{ + pub fn parse<'a>(dialect: Dialect, sql: impl Into<&'a str>) -> Result { + let tree = dialect.parse_sql::(sql)?; + Ok(Self { dialect, tree }) } } pub use diff::DiffError; pub use migration::MigrateError; -impl SyntaxTree { - pub fn diff(&self, other: &SyntaxTree) -> Result, DiffError> { - Ok(Diff::diff(&self.0, &other.0)?.map(Self)) +impl SyntaxTree +where + Dialect: TreeDiffer, +{ + pub fn diff(&self, other: &SyntaxTree) -> Result, DiffError> { + Ok( + TreeDiffer::diff_tree(&self.dialect, &self.tree, &other.tree)?.map(|tree| Self { + dialect: self.dialect.clone(), + tree, + }), + ) } +} - pub fn migrate(self, other: &SyntaxTree) -> Result, MigrateError> { - Ok(Migrate::migrate(self.0, &other.0)?.map(Self)) +impl SyntaxTree +where + Dialect: TreeMigrator, +{ + pub fn migrate(self, other: &SyntaxTree) -> Result { + let tree = TreeMigrator::migrate_tree(&self.dialect, self.tree, &other.tree)?; + Ok(Self { + dialect: self.dialect.clone(), + tree, + }) } } -impl fmt::Display for SyntaxTree { +impl fmt::Display for SyntaxTree { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let mut iter = self.0.iter().peekable(); + let mut iter = self.tree.iter().peekable(); while let Some(s) = iter.next() { let formatted = sqlformat::format( format!("{s};").as_str(), @@ -105,11 +92,12 @@ impl fmt::Display for SyntaxTree { #[cfg(test)] mod tests { + use super::dialect::Generic; use super::*; macro_rules! test_case { ( - @dialect($dialect:path) $(,)? + @dialect($dialect:ty) $(,)? $( $test_name:ident { $( $field:ident : $value:literal ),+ $(,)? } @@ -120,8 +108,10 @@ mod tests { $( #[test] fn $test_name() { - let test_case = TestCase { - dialect: $dialect, + let dialect = <$dialect>::default(); + + let test_case: TestCase<$dialect> = TestCase { + dialect: dialect.clone(), $( $field : $value ),+ }; @@ -132,32 +122,23 @@ mod tests { } #[derive(Debug)] - struct TestCase { + struct TestCase { dialect: Dialect, sql_a: &'static str, sql_b: &'static str, expect: &'static str, } - fn run_test_case(tc: &TestCase, testfn: F) + fn run_test_case(tc: &TestCase, testfn: F) where - F: Fn(SyntaxTree, SyntaxTree) -> Result, E>, + Dialect: Parse + TreeDiffer, E: std::error::Error, + F: Fn(SyntaxTree, SyntaxTree) -> Result>, E>, { - let ast_a = SyntaxTree::builder() - .dialect(tc.dialect) - .sql(tc.sql_a) - .build() - .unwrap(); - let ast_b = SyntaxTree::builder() - .dialect(tc.dialect) - .sql(tc.sql_b) - .build() - .unwrap(); - SyntaxTree::builder() - .dialect(tc.dialect) - .sql(tc.expect) - .build() + let dialect = tc.dialect.clone(); + let ast_a = SyntaxTree::parse(dialect.clone(), tc.sql_a).unwrap(); + let ast_b = SyntaxTree::parse(dialect.clone(), tc.sql_b).unwrap(); + SyntaxTree::parse(dialect, tc.expect) .unwrap_or_else(|_| panic!("invalid SQL: {:?}", tc.expect)); let actual = testfn(ast_a, ast_b) .inspect_err(|err| eprintln!("Error: {err:?}")) @@ -170,7 +151,7 @@ mod tests { use super::*; test_case!( - @dialect(Dialect::Generic) + @dialect(Generic) create_table_a { sql_a: "CREATE TABLE foo(\ @@ -304,7 +285,7 @@ mod tests { ); test_case!( - @dialect(Dialect::Generic) + @dialect(Generic) create_domain_a { sql_a: "", @@ -325,10 +306,12 @@ mod tests { } mod migrate { + use crate::dialect::PostgreSQL; + use super::*; test_case!( - @dialect(Dialect::Generic) + @dialect(Generic) create_table_a { sql_a: "CREATE TABLE bar (id INT PRIMARY KEY);", @@ -451,12 +434,12 @@ mod tests { }, => |ast_a, ast_b| { - ast_a.migrate(&ast_b) + Some(ast_a.migrate(&ast_b)).transpose() } ); test_case!( - @dialect(Dialect::PostgreSql) + @dialect(PostgreSQL) alter_table_alter_column_e { sql_a: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)", @@ -471,7 +454,7 @@ mod tests { }, => |ast_a, ast_b| { - ast_a.migrate(&ast_b) + Some(ast_a.migrate(&ast_b)).transpose() } ); } diff --git a/src/migration.rs b/src/migration.rs index ff2d1bd..d355f80 100644 --- a/src/migration.rs +++ b/src/migration.rs @@ -1,14 +1,20 @@ use std::fmt; use bon::bon; +use sqlparser::ast::{CreateDomain, CreateIndex}; use thiserror::Error; -use crate::ast::{ - AlterColumnOperation, AlterTable, AlterTableOperation, AlterType, AlterTypeAddValuePosition, - AlterTypeOperation, ColumnOption, ColumnOptionDef, CreateExtension, CreateTable, DropExtension, - GeneratedAs, ObjectName, ObjectNamePart, ObjectType, Statement, UserDefinedTypeRepresentation, +use crate::{ + ast::{ + AlterTable, AlterTableOperation, AlterType, AlterTypeOperation, CreateExtension, + CreateTable, CreateType, Statement, + }, + dialect::{Generic, PostgreSQL, SQLite}, + sealed::Sealed, }; +pub mod generic; + #[derive(Error, Debug)] pub struct MigrateError { kind: MigrateErrorKind, @@ -52,8 +58,6 @@ impl MigrateError { #[derive(Error, Debug)] #[non_exhaustive] enum MigrateErrorKind { - #[error("can't migrate unnamed index")] - UnnamedIndex, #[error("ALTER TABLE operation \"{0}\" not yet supported")] AlterTableOpNotImplemented(Box), #[error("invalid ALTER TYPE operation \"{0}\"")] @@ -62,358 +66,105 @@ enum MigrateErrorKind { NotImplemented, } -pub(crate) trait Migrate: Sized { - fn migrate(self, other: &Self) -> Result, MigrateError>; -} +type Result = std::result::Result; -impl Migrate for Vec { - fn migrate(self, other: &Self) -> Result, MigrateError> { - let next: Self = self - .into_iter() - // perform any transformations on existing schema (e.g. ALTER/DROP table) - .filter_map(|sa| { - let orig = sa.clone(); - match &sa { - Statement::CreateTable(ca) => other - .iter() - .find(|sb| match sb { - Statement::AlterTable(AlterTable { name, .. }) => *name == ca.name, - Statement::Drop { - object_type, names, .. - } => { - *object_type == ObjectType::Table - && names.len() == 1 - && names[0] == ca.name - } - _ => false, - }) - .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()), - Statement::CreateIndex(a) => other - .iter() - .find(|sb| match sb { - Statement::Drop { - object_type, names, .. - } => { - *object_type == ObjectType::Index - && names.len() == 1 - && Some(&names[0]) == a.name.as_ref() - } - _ => false, - }) - .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()), - Statement::CreateType { name, .. } => other - .iter() - .find(|sb| match sb { - Statement::AlterType(b) => *name == b.name, - Statement::Drop { - object_type, names, .. - } => { - *object_type == ObjectType::Type - && names.len() == 1 - && names[0] == *name - } - _ => false, - }) - .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()), - Statement::CreateExtension(CreateExtension { name, .. }) => other - .iter() - .find(|sb| match sb { - Statement::DropExtension(DropExtension { names, .. }) => { - names.contains(name) - } - _ => false, - }) - .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()), - Statement::CreateDomain(a) => other - .iter() - .find(|sb| match sb { - Statement::DropDomain(b) => a.name == b.name, - _ => false, - }) - .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()), - _ => Some(Err(MigrateError::builder() - .kind(MigrateErrorKind::NotImplemented) - .statement_a(sa.clone()) - .build())), - } - }) - // CREATE table etc. - .chain(other.iter().filter_map(|sb| match sb { - Statement::CreateTable(_) - | Statement::CreateIndex { .. } - | Statement::CreateType { .. } - | Statement::CreateExtension { .. } - | Statement::CreateDomain(..) => Some(Ok(sb.clone())), - _ => None, - })) - .collect::>()?; - Ok(Some(next)) +pub trait TreeMigrator: StatementMigrator + Sealed { + fn migrate_tree(&self, a: Vec, b: &[Statement]) -> Result> { + generic::tree::migrate_tree(self, a, b) } -} -impl Migrate for Statement { - fn migrate(self, other: &Self) -> Result, MigrateError> { - match self { - Self::CreateTable(ca) => match other { - Self::AlterTable(AlterTable { - name, operations, .. - }) => { - if *name == ca.name { - Ok(Some(Self::CreateTable(migrate_alter_table( - ca, operations, - )?))) - } else { - // ALTER TABLE statement for another table - Ok(Some(Self::CreateTable(ca))) - } - } - Self::Drop { - object_type, names, .. - } => { - if *object_type == ObjectType::Table && names.contains(&ca.name) { - Ok(None) - } else { - // DROP statement is for another table - Ok(Some(Self::CreateTable(ca))) - } - } - _ => Err(MigrateError::builder() - .kind(MigrateErrorKind::NotImplemented) - .statement_a(Self::CreateTable(ca)) - .statement_b(other.clone()) - .build()), - }, - Self::CreateIndex(a) => match other { - Self::Drop { - object_type, names, .. - } => { - let name = a.name.clone().ok_or_else(|| { - MigrateError::builder() - .kind(MigrateErrorKind::UnnamedIndex) - .statement_a(Self::CreateIndex(a.clone())) - .statement_b(other.clone()) - .build() - })?; - if *object_type == ObjectType::Index && names.contains(&name) { - Ok(None) - } else { - // DROP statement is for another index - Ok(Some(Self::CreateIndex(a))) - } - } - _ => Err(MigrateError::builder() - .kind(MigrateErrorKind::NotImplemented) - .statement_a(Self::CreateIndex(a)) - .statement_b(other.clone()) - .build()), - }, - Self::CreateType { - name, - representation, - } => match other { - Self::AlterType(ba) => { - if name == ba.name { - let (name, representation) = - migrate_alter_type(name.clone(), representation.clone(), ba)?; - Ok(Some(Self::CreateType { - name, - representation, - })) - } else { - // ALTER TYPE statement for another type - Ok(Some(Self::CreateType { - name, - representation, - })) - } - } - Self::Drop { - object_type, names, .. - } => { - if *object_type == ObjectType::Type && names.contains(&name) { - Ok(None) - } else { - // DROP statement is for another type - Ok(Some(Self::CreateType { - name, - representation, - })) - } - } - _ => Err(MigrateError::builder() - .kind(MigrateErrorKind::NotImplemented) - .statement_a(Self::CreateType { - name, - representation, - }) - .statement_b(other.clone()) - .build()), - }, - _ => Err(MigrateError::builder() - .kind(MigrateErrorKind::NotImplemented) - .statement_a(self) - .statement_b(other.clone()) - .build()), - } + fn match_and_migrate_create_table( + &self, + sa: &Statement, + a: &CreateTable, + b: &[Statement], + ) -> Result> { + generic::tree::match_and_migrate_create_table(self, sa, a, b) } -} -fn migrate_alter_table( - mut t: CreateTable, - ops: &[AlterTableOperation], -) -> Result { - for op in ops.iter() { - match op { - AlterTableOperation::AddColumn { column_def, .. } => { - t.columns.push(column_def.clone()); - } - AlterTableOperation::DropColumn { column_names, .. } => { - t.columns - .retain(|c| !column_names.iter().any(|name| c.name.value == name.value)); - } - AlterTableOperation::AlterColumn { column_name, op } => { - t.columns.iter_mut().for_each(|c| { - if c.name != *column_name { - return; - } - match op { - AlterColumnOperation::SetNotNull => { - c.options.push(ColumnOptionDef { - name: None, - option: ColumnOption::NotNull, - }); - } - AlterColumnOperation::DropNotNull => { - c.options - .retain(|o| !matches!(o.option, ColumnOption::NotNull)); - } - AlterColumnOperation::SetDefault { value } => { - c.options - .retain(|o| !matches!(o.option, ColumnOption::Default(_))); - c.options.push(ColumnOptionDef { - name: None, - option: ColumnOption::Default(value.clone()), - }); - } - AlterColumnOperation::DropDefault => { - c.options - .retain(|o| !matches!(o.option, ColumnOption::Default(_))); - } - AlterColumnOperation::SetDataType { - data_type, - using: _, // not applicable since we're not running the query - had_set: _, // this doesn't change the meaning - } => { - c.data_type = data_type.clone(); - } - AlterColumnOperation::AddGenerated { - generated_as, - sequence_options, - } => { - c.options - .retain(|o| !matches!(o.option, ColumnOption::Generated { .. })); - c.options.push(ColumnOptionDef { - name: None, - option: ColumnOption::Generated { - generated_as: (*generated_as).unwrap_or(GeneratedAs::Always), - sequence_options: sequence_options.clone(), - generation_expr: None, - generation_expr_mode: None, - generated_keyword: true, - }, - }); - } - } - }); - } - op => { - return Err(MigrateError::builder() - .kind(MigrateErrorKind::AlterTableOpNotImplemented(Box::new( - op.clone(), - ))) - .statement_a(Statement::CreateTable(t.clone())) - .build()) - } - } + fn match_and_migrate_create_index( + &self, + sa: &Statement, + a: &CreateIndex, + b: &[Statement], + ) -> Result> { + generic::tree::match_and_migrate_create_index(self, sa, a, b) + } + + fn match_and_migrate_create_type( + &self, + sa: &Statement, + a: &CreateType, + b: &[Statement], + ) -> Result> { + generic::tree::match_and_migrate_create_type(self, sa, a, b) } - Ok(t) + fn match_and_migrate_create_extension( + &self, + sa: &Statement, + a: &CreateExtension, + b: &[Statement], + ) -> Result> { + generic::tree::match_and_migrate_create_extension(self, sa, a, b) + } + + fn match_and_migrate_create_domain( + &self, + sa: &Statement, + a: &CreateDomain, + b: &[Statement], + ) -> Result> { + generic::tree::match_and_migrate_create_domain(self, sa, a, b) + } } -fn migrate_alter_type( - name: ObjectName, - representation: Option, - other: &AlterType, -) -> Result<(ObjectName, Option), MigrateError> { - match &other.operation { - AlterTypeOperation::Rename(r) => { - let mut parts = name.0; - parts.pop(); - parts.push(ObjectNamePart::Identifier(r.new_name.clone())); - let name = ObjectName(parts); +impl TreeMigrator for Generic {} - Ok((name, representation)) - } - AlterTypeOperation::AddValue(a) => match representation { - Some(UserDefinedTypeRepresentation::Enum { mut labels }) => { - match &a.position { - Some(AlterTypeAddValuePosition::Before(before_name)) => { - let index = labels - .iter() - .enumerate() - .find(|(_, l)| *l == before_name) - .map(|(i, _)| i) - // insert at the beginning if `before_name` can't be found - .unwrap_or(0); - labels.insert(index, a.value.clone()); - } - Some(AlterTypeAddValuePosition::After(after_name)) => { - let index = labels - .iter() - .enumerate() - .find(|(_, l)| *l == after_name) - .map(|(i, _)| i + 1); - match index { - Some(index) => labels.insert(index, a.value.clone()), - // push it to the end if `after_name` can't be found - None => labels.push(a.value.clone()), - } - } - None => labels.push(a.value.clone()), - } +impl TreeMigrator for PostgreSQL {} - Ok((name, Some(UserDefinedTypeRepresentation::Enum { labels }))) - } - Some(_) | None => Err(MigrateError::builder() - .kind(MigrateErrorKind::AlterTypeInvalidOp(Box::new( - other.operation.clone(), - ))) - .statement_a(Statement::CreateType { - name, - representation, - }) - .statement_b(Statement::AlterType(other.clone())) - .build()), - }, - AlterTypeOperation::RenameValue(rv) => match representation { - Some(UserDefinedTypeRepresentation::Enum { labels }) => { - let labels = labels - .into_iter() - .map(|l| if l == rv.from { rv.to.clone() } else { l }) - .collect::>(); +impl TreeMigrator for SQLite {} - Ok((name, Some(UserDefinedTypeRepresentation::Enum { labels }))) - } - Some(_) | None => Err(MigrateError::builder() - .kind(MigrateErrorKind::AlterTypeInvalidOp(Box::new( - other.operation.clone(), - ))) - .statement_a(Statement::CreateType { - name, - representation, - }) - .statement_b(Statement::AlterType(other.clone())) - .build()), - }, +pub trait StatementMigrator: fmt::Debug + Default + Clone + Sized + Sealed { + fn migrate(&self, a: &Statement, b: &Statement) -> Result> { + generic::statement::migrate(self, a, b) + } + + fn migrate_create_table(&self, a: &CreateTable, sb: &Statement) -> Result> { + generic::statement::migrate_create_table(self, a, sb) + } + + fn migrate_alter_table(&self, a: &CreateTable, b: &AlterTable) -> Result> { + generic::statement::migrate_alter_table(self, a, b) + } + + fn migrate_create_index(&self, a: &CreateIndex, sb: &Statement) -> Result> { + generic::statement::migrate_create_index(self, a, sb) + } + + fn migrate_create_type(&self, a: &CreateType, sb: &Statement) -> Result> { + generic::statement::migrate_create_type(self, a, sb) + } + + fn migrate_alter_type(&self, a: &CreateType, b: &AlterType) -> Result> { + generic::statement::migrate_alter_type(self, a, b) + } + + fn migrate_create_extension( + &self, + a: &CreateExtension, + sb: &Statement, + ) -> Result> { + generic::statement::migrate_create_extension(self, a, sb) + } + + fn migrate_create_domain(&self, a: &CreateDomain, sb: &Statement) -> Result> { + generic::statement::migrate_create_domain(self, a, sb) } } + +impl StatementMigrator for Generic {} + +impl StatementMigrator for PostgreSQL {} + +impl StatementMigrator for SQLite {} diff --git a/src/migration/generic.rs b/src/migration/generic.rs new file mode 100644 index 0000000..5b31d1b --- /dev/null +++ b/src/migration/generic.rs @@ -0,0 +1,2 @@ +pub mod statement; +pub mod tree; diff --git a/src/migration/generic/statement.rs b/src/migration/generic/statement.rs new file mode 100644 index 0000000..e630f16 --- /dev/null +++ b/src/migration/generic/statement.rs @@ -0,0 +1,329 @@ +use crate::{ + ast::{ + AlterColumnOperation, AlterTable, AlterTableOperation, AlterType, + AlterTypeAddValuePosition, AlterTypeOperation, ColumnOption, ColumnOptionDef, CreateDomain, + CreateExtension, CreateIndex, CreateTable, CreateType, GeneratedAs, ObjectName, + ObjectNamePart, ObjectType, Statement, UserDefinedTypeRepresentation, + }, + migration::{MigrateError, MigrateErrorKind, Result, StatementMigrator}, +}; + +pub fn migrate( + dialect: &Dialect, + sa: &Statement, + sb: &Statement, +) -> Result> { + match sa { + Statement::CreateTable(a) => dialect.migrate_create_table(a, sb), + Statement::CreateIndex(a) => dialect.migrate_create_index(a, sb), + Statement::CreateType { + name, + representation, + } => dialect.migrate_create_type( + &CreateType { + name: name.clone(), + representation: representation.clone(), + }, + sb, + ), + Statement::CreateExtension(a) => dialect.migrate_create_extension(a, sb), + Statement::CreateDomain(a) => dialect.migrate_create_domain(a, sb), + _ => Err(MigrateError::builder() + .kind(MigrateErrorKind::NotImplemented) + .statement_a(sa.clone()) + .statement_b(sb.clone()) + .build()), + } +} + +pub fn migrate_create_table( + dialect: &Dialect, + a: &CreateTable, + sb: &Statement, +) -> Result> { + match &sb { + Statement::AlterTable(b) => dialect.migrate_alter_table(a, b), + Statement::Drop { + object_type, names, .. + } => { + assert_eq!( + *object_type, + ObjectType::Table, + "attempt to apply non-table DROP to {}", + a.name + ); + assert!( + names.contains(&a.name), + "attempt to apply DROP {:?} to {}", + names, + a.name + ); + Ok(Vec::with_capacity(0)) + } + _ => Err(MigrateError::builder() + .kind(MigrateErrorKind::NotImplemented) + .statement_a(Statement::CreateTable(a.clone())) + .statement_b(sb.clone()) + .build()), + } +} + +pub fn migrate_create_index( + _dialect: &Dialect, + a: &CreateIndex, + sb: &Statement, +) -> Result> { + match sb { + Statement::Drop { + object_type, names, .. + } => { + let name = a + .name + .clone() + .expect("index must be named to apply drop statement"); + assert_eq!( + *object_type, + ObjectType::Index, + "attempt to apply non-index DROP to index {name}" + ); + assert!( + names.contains(&name), + "attempt to apply DROP index {names:?} to {name}" + ); + Ok(Vec::with_capacity(0)) + } + _ => Err(MigrateError::builder() + .kind(MigrateErrorKind::NotImplemented) + .statement_a(Statement::CreateIndex(a.clone())) + .statement_b(sb.clone()) + .build()), + } +} + +pub fn migrate_create_type( + dialect: &Dialect, + a: &CreateType, + sb: &Statement, +) -> Result> { + match sb { + Statement::AlterType(b) => dialect.migrate_alter_type(a, b), + Statement::Drop { + object_type, names, .. + } => { + assert_eq!( + *object_type, + ObjectType::Type, + "attempt to apply non-type DROP to TYPE {}", + a.name + ); + assert!( + names.contains(&a.name), + "attempt to apply DROP {names:?} to {}", + a.name + ); + + Ok(Vec::with_capacity(0)) + } + _ => Err(MigrateError::builder() + .kind(MigrateErrorKind::NotImplemented) + .statement_a(a.clone().into()) + .statement_b(sb.clone()) + .build()), + } +} + +pub fn migrate_create_extension( + _dialect: &Dialect, + _a: &CreateExtension, + _b: &Statement, +) -> Result> { + todo!() +} + +pub fn migrate_create_domain( + _dialect: &Dialect, + _a: &CreateDomain, + _b: &Statement, +) -> Result> { + todo!() +} + +pub fn migrate_alter_table( + _dialect: &Dialect, + a: &CreateTable, + b: &AlterTable, +) -> Result, MigrateError> { + assert_eq!( + a.name, b.name, + "attempt to apply ALTER TABLE {} to {}", + b.name, a.name + ); + + let mut a = a.clone(); + for op in b.operations.iter() { + match op { + AlterTableOperation::AddColumn { column_def, .. } => { + a.columns.push(column_def.clone()); + } + AlterTableOperation::DropColumn { column_names, .. } => { + a.columns + .retain(|c| !column_names.iter().any(|name| c.name.value == name.value)); + } + AlterTableOperation::AlterColumn { column_name, op } => { + a.columns.iter_mut().for_each(|c| { + if c.name != *column_name { + return; + } + match op { + AlterColumnOperation::SetNotNull => { + c.options.push(ColumnOptionDef { + name: None, + option: ColumnOption::NotNull, + }); + } + AlterColumnOperation::DropNotNull => { + c.options + .retain(|o| !matches!(o.option, ColumnOption::NotNull)); + } + AlterColumnOperation::SetDefault { value } => { + c.options + .retain(|o| !matches!(o.option, ColumnOption::Default(_))); + c.options.push(ColumnOptionDef { + name: None, + option: ColumnOption::Default(value.clone()), + }); + } + AlterColumnOperation::DropDefault => { + c.options + .retain(|o| !matches!(o.option, ColumnOption::Default(_))); + } + AlterColumnOperation::SetDataType { + data_type, + using: _, // not applicable since we're not running the query + had_set: _, // this doesn't change the meaning + } => { + c.data_type = data_type.clone(); + } + AlterColumnOperation::AddGenerated { + generated_as, + sequence_options, + } => { + c.options + .retain(|o| !matches!(o.option, ColumnOption::Generated { .. })); + c.options.push(ColumnOptionDef { + name: None, + option: ColumnOption::Generated { + generated_as: (*generated_as).unwrap_or(GeneratedAs::Always), + sequence_options: sequence_options.clone(), + generation_expr: None, + generation_expr_mode: None, + generated_keyword: true, + }, + }); + } + } + }); + } + op => { + return Err(MigrateError::builder() + .kind(MigrateErrorKind::AlterTableOpNotImplemented(Box::new( + op.clone(), + ))) + .statement_a(Statement::CreateTable(a.clone())) + .build()) + } + } + } + + Ok(vec![Statement::CreateTable(a)]) +} + +pub fn migrate_alter_type( + _dialect: &Dialect, + a: &CreateType, + b: &AlterType, +) -> Result, MigrateError> { + assert_eq!( + a.name, b.name, + "attempt to apply ALTER TYPE {} to {}", + b.name, a.name + ); + + let (name, representation) = match &b.operation { + AlterTypeOperation::Rename(r) => { + let mut parts = a.name.0.clone(); + parts.pop(); + parts.push(ObjectNamePart::Identifier(r.new_name.clone())); + let name = ObjectName(parts); + + Ok((name, a.representation.clone())) + } + AlterTypeOperation::AddValue(av) => match &a.representation { + Some(UserDefinedTypeRepresentation::Enum { labels }) => { + let mut labels = labels.clone(); + match &av.position { + Some(AlterTypeAddValuePosition::Before(before_name)) => { + let index = labels + .iter() + .enumerate() + .find(|(_, l)| *l == before_name) + .map(|(i, _)| i) + // insert at the beginning if `before_name` can't be found + .unwrap_or(0); + labels.insert(index, av.value.clone()); + } + Some(AlterTypeAddValuePosition::After(after_name)) => { + let index = labels + .iter() + .enumerate() + .find(|(_, l)| *l == after_name) + .map(|(i, _)| i + 1); + match index { + Some(index) => labels.insert(index, av.value.clone()), + // push it to the end if `after_name` can't be found + None => labels.push(av.value.clone()), + } + } + None => labels.push(av.value.clone()), + } + + Ok(( + a.name.clone(), + Some(UserDefinedTypeRepresentation::Enum { labels }), + )) + } + Some(_) | None => Err(MigrateError::builder() + .kind(MigrateErrorKind::AlterTypeInvalidOp(Box::new( + b.operation.clone(), + ))) + .statement_a(a.clone().into()) + .statement_b(Statement::AlterType(b.clone())) + .build()), + }, + AlterTypeOperation::RenameValue(rv) => match &a.representation { + Some(UserDefinedTypeRepresentation::Enum { labels }) => { + let labels = labels + .iter() + .cloned() + .map(|l| if l == rv.from { rv.to.clone() } else { l }) + .collect::>(); + + Ok(( + a.name.clone(), + Some(UserDefinedTypeRepresentation::Enum { labels }), + )) + } + Some(_) | None => Err(MigrateError::builder() + .kind(MigrateErrorKind::AlterTypeInvalidOp(Box::new( + b.operation.clone(), + ))) + .statement_a(a.clone().into()) + .statement_b(Statement::AlterType(b.clone())) + .build()), + }, + }?; + Ok(vec![Statement::CreateType { + name, + representation, + }]) +} diff --git a/src/migration/generic/tree.rs b/src/migration/generic/tree.rs new file mode 100644 index 0000000..07957a4 --- /dev/null +++ b/src/migration/generic/tree.rs @@ -0,0 +1,142 @@ +use crate::{ + ast::{ + AlterTable, CreateDomain, CreateExtension, CreateIndex, CreateTable, CreateType, + DropExtension, ObjectType, Statement, + }, + migration::{MigrateError, MigrateErrorKind, Result, StatementMigrator, TreeMigrator}, +}; + +pub fn migrate_tree( + dialect: &Dialect, + a: Vec, + b: &[Statement], +) -> Result> { + let next = a + .into_iter() + // perform any transformations on existing schema (e.g. ALTER/DROP table) + .map(|sa| match &sa { + Statement::CreateTable(a) => dialect.match_and_migrate_create_table(&sa, a, b), + Statement::CreateIndex(a) => dialect.match_and_migrate_create_index(&sa, a, b), + Statement::CreateType { + name, + representation, + } => dialect.match_and_migrate_create_type( + &sa, + &CreateType { + name: name.clone(), + representation: representation.clone(), + }, + b, + ), + Statement::CreateExtension(a) => dialect.match_and_migrate_create_extension(&sa, a, b), + Statement::CreateDomain(a) => dialect.match_and_migrate_create_domain(&sa, a, b), + _ => Err(MigrateError::builder() + .kind(MigrateErrorKind::NotImplemented) + .statement_a(sa.clone()) + .build()), + }) + // CREATE table etc. + .chain(b.iter().filter_map(|sb| match sb { + Statement::CreateTable(_) + | Statement::CreateIndex { .. } + | Statement::CreateType { .. } + | Statement::CreateExtension { .. } + | Statement::CreateDomain(..) => Some(Ok(vec![sb.clone()])), + _ => None, + })) + .collect::, _>>()? + .into_iter() + .flatten() + .collect::>(); + Ok(next) +} + +fn match_and_migrate( + dialect: &Dialect, + sa: &Statement, + b: &[Statement], + match_fn: MF, +) -> Result> +where + Dialect: StatementMigrator, + MF: Fn(&&Statement) -> bool, +{ + b.iter().find(match_fn).map_or_else( + // keep the statement as-is if there's no counterpart + || Ok(vec![sa.clone()]), + // otherwise diff the two statements + |sb| StatementMigrator::migrate(dialect, sa, sb), + ) +} + +pub fn match_and_migrate_create_table( + dialect: &Dialect, + sa: &Statement, + a: &CreateTable, + b: &[Statement], +) -> Result> { + match_and_migrate(dialect, sa, b, |sb| match sb { + Statement::AlterTable(AlterTable { name, .. }) => *name == a.name, + Statement::Drop { + object_type, names, .. + } => *object_type == ObjectType::Table && names.len() == 1 && names[0] == a.name, + _ => false, + }) +} + +pub fn match_and_migrate_create_index( + dialect: &Dialect, + sa: &Statement, + a: &CreateIndex, + b: &[Statement], +) -> Result> { + match_and_migrate(dialect, sa, b, |sb| match sb { + Statement::Drop { + object_type, names, .. + } => { + *object_type == ObjectType::Index + && names.len() == 1 + && Some(&names[0]) == a.name.as_ref() + } + _ => false, + }) +} + +pub fn match_and_migrate_create_type( + dialect: &Dialect, + sa: &Statement, + a: &CreateType, + b: &[Statement], +) -> Result> { + match_and_migrate(dialect, sa, b, |sb| match sb { + Statement::AlterType(b) => a.name == b.name, + Statement::Drop { + object_type, names, .. + } => *object_type == ObjectType::Type && names.len() == 1 && names[0] == a.name, + _ => false, + }) +} + +pub fn match_and_migrate_create_extension( + dialect: &Dialect, + sa: &Statement, + a: &CreateExtension, + b: &[Statement], +) -> Result> { + match_and_migrate(dialect, sa, b, |sb| match sb { + Statement::DropExtension(DropExtension { names, .. }) => names.contains(&a.name), + _ => false, + }) +} + +pub fn match_and_migrate_create_domain( + dialect: &Dialect, + sa: &Statement, + a: &CreateDomain, + b: &[Statement], +) -> Result> { + match_and_migrate(dialect, sa, b, |sb| match sb { + Statement::DropDomain(b) => a.name == b.name, + _ => false, + }) +} diff --git a/src/name_gen.rs b/src/name_gen.rs index 3b13b17..aa6fe21 100644 --- a/src/name_gen.rs +++ b/src/name_gen.rs @@ -7,12 +7,12 @@ use crate::{ }; #[bon::builder(finish_fn = build)] -pub fn generate_name( - #[builder(start_fn)] tree: &SyntaxTree, +pub fn generate_name( + #[builder(start_fn)] tree: &SyntaxTree, max_len: Option, ) -> Option { let mut parts = tree - .0 + .tree .iter() .filter_map(|s| match s { Statement::CreateTable(CreateTable { name, .. }) => Some(format!("create_{name}")), @@ -113,6 +113,7 @@ fn alter_table_name(name: &ObjectName, operations: &[AlterTableOperation]) -> Op #[cfg(test)] mod tests { use super::*; + use crate::dialect; #[derive(Debug)] struct TestCase { @@ -121,7 +122,7 @@ mod tests { } fn run_test_case(tc: &TestCase) { - let tree = SyntaxTree::builder().sql(tc.sql).build().unwrap(); + let tree = SyntaxTree::parse(dialect::Generic, tc.sql).unwrap(); let actual = generate_name(&tree).build(); assert_eq!(actual, Some(tc.name.to_owned()), "{tc:?}"); } diff --git a/src/parser.rs b/src/parser.rs new file mode 100644 index 0000000..c615695 --- /dev/null +++ b/src/parser.rs @@ -0,0 +1,49 @@ +use thiserror::Error; + +use crate::{ast, dialect, sealed::Sealed}; + +#[derive(Error, Debug)] +#[error("Oops, we couldn't parse that!")] +pub struct ParseError(#[from] sqlparser::parser::ParserError); + +pub trait Parse: Sealed { + fn parse_sql<'a, Dialect>( + &self, + sql: impl Into<&'a str>, + ) -> Result, ParseError>; +} + +fn parse_sql<'a>( + dialect: Box, + sql: impl Into<&'a str>, +) -> Result, ParseError> { + let tree = sqlparser::parser::Parser::parse_sql(dialect.as_ref(), sql.into())?; + Ok(tree) +} + +impl Parse for dialect::Generic { + fn parse_sql<'a, Dialect>( + &self, + sql: impl Into<&'a str>, + ) -> Result, ParseError> { + parse_sql(Box::new(sqlparser::dialect::GenericDialect {}), sql) + } +} + +impl Parse for dialect::PostgreSQL { + fn parse_sql<'a, Dialect>( + &self, + sql: impl Into<&'a str>, + ) -> Result, ParseError> { + parse_sql(Box::new(sqlparser::dialect::PostgreSqlDialect {}), sql) + } +} + +impl Parse for dialect::SQLite { + fn parse_sql<'a, Dialect>( + &self, + sql: impl Into<&'a str>, + ) -> Result, ParseError> { + parse_sql(Box::new(sqlparser::dialect::SQLiteDialect {}), sql) + } +} diff --git a/src/sealed.rs b/src/sealed.rs new file mode 100644 index 0000000..0650b2f --- /dev/null +++ b/src/sealed.rs @@ -0,0 +1 @@ +pub trait Sealed {} From 9e9eb20617029164474a3aebf06f6bbbae71dd7e Mon Sep 17 00:00:00 2001 From: Jesse Stuart Date: Tue, 14 Apr 2026 18:43:58 -0400 Subject: [PATCH 6/6] fix: add support for DROP {EXTENSION,DOMAIN} --- src/lib.rs | 24 ++++++++++++++--- src/migration/generic/statement.rs | 41 +++++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 9 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 4f06f26..6b42eab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -379,24 +379,30 @@ mod tests { expect: "CREATE UNIQUE INDEX title_idx ON films(title);\n\nCREATE INDEX code_idx ON films(code);", }, - create_index_b { + drop_index_a { sql_a: "CREATE UNIQUE INDEX title_idx ON films (title);", sql_b: "DROP INDEX title_idx;", expect: "", }, - create_index_c { + drop_index_b { sql_a: "CREATE UNIQUE INDEX title_idx ON films (title);", sql_b: "DROP INDEX title_idx;CREATE INDEX code_idx ON films (code);", expect: "CREATE INDEX code_idx ON films(code);", }, - alter_create_type_a { + create_type_a { sql_a: "CREATE TYPE bug_status AS ENUM ('open', 'closed');", sql_b: "CREATE TYPE compfoo AS (f1 int, f2 text);", expect: "CREATE TYPE bug_status AS ENUM ('open', 'closed');\n\nCREATE TYPE compfoo AS (f1 INT, f2 TEXT);", }, + drop_type_a { + sql_a: "CREATE TYPE bug_status AS ENUM ('open', 'closed'); CREATE TYPE compfoo AS (f1 int, f2 text);", + sql_b: "DROP TYPE bug_status;", + expect: "CREATE TYPE compfoo AS (f1 INT, f2 TEXT);", + }, + alter_type_rename_a { sql_a: "CREATE TYPE bug_status AS ENUM ('open', 'closed');", sql_b: "ALTER TYPE bug_status RENAME TO issue_status", @@ -433,6 +439,12 @@ mod tests { expect: "CREATE EXTENSION hstore;\n\nCREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";", }, + drop_extension_a { + sql_a: "CREATE EXTENSION hstore; CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";", + sql_b: "DROP EXTENSION hstore;", + expect: "CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";", + }, + => |ast_a, ast_b| { Some(ast_a.migrate(&ast_b)).transpose() } @@ -453,6 +465,12 @@ mod tests { expect: "CREATE DOMAIN positive_int AS INTEGER CHECK (VALUE > 0);\n\nCREATE DOMAIN email AS VARCHAR(255) CHECK (\n VALUE ~ '^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$'\n);", }, + drop_domain_a { + sql_a: "CREATE DOMAIN positive_int AS INTEGER CHECK (VALUE > 0); CREATE DOMAIN above_ten AS INTEGER CHECK (VALUE > 10);", + sql_b: "DROP DOMAIN above_ten;", + expect: "CREATE DOMAIN positive_int AS INTEGER CHECK (VALUE > 0);", + }, + => |ast_a, ast_b| { Some(ast_a.migrate(&ast_b)).transpose() } diff --git a/src/migration/generic/statement.rs b/src/migration/generic/statement.rs index e630f16..a0c9031 100644 --- a/src/migration/generic/statement.rs +++ b/src/migration/generic/statement.rs @@ -134,18 +134,47 @@ pub fn migrate_create_type( pub fn migrate_create_extension( _dialect: &Dialect, - _a: &CreateExtension, - _b: &Statement, + a: &CreateExtension, + sb: &Statement, ) -> Result> { - todo!() + match sb { + Statement::DropExtension(b) => { + assert!( + b.names.contains(&a.name), + "attempt to DROP EXTENSION {:?} for {}", + b.names, + a.name + ); + Ok(Vec::with_capacity(0)) + } + _ => Err(MigrateError::builder() + .kind(MigrateErrorKind::NotImplemented) + .statement_a(a.clone().into()) + .statement_b(sb.clone()) + .build()), + } } pub fn migrate_create_domain( _dialect: &Dialect, - _a: &CreateDomain, - _b: &Statement, + a: &CreateDomain, + sb: &Statement, ) -> Result> { - todo!() + match sb { + Statement::DropDomain(b) => { + assert_eq!( + a.name, b.name, + "attempt to DROP DOMAIN {} for {}", + b.name, a.name + ); + Ok(Vec::with_capacity(0)) + } + _ => Err(MigrateError::builder() + .kind(MigrateErrorKind::NotImplemented) + .statement_a(a.clone().into()) + .statement_b(sb.clone()) + .build()), + } } pub fn migrate_alter_table(