diff --git a/Cargo.lock b/Cargo.lock index 8ae3bda..958281b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -429,9 +429,9 @@ dependencies = [ [[package]] name = "sqlparser" -version = "0.57.0" +version = "0.61.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07c5f081b292a3d19637f0b32a79e28ff14a9fd23ef47bd7fce08ff5de221eca" +checksum = "dbf5ea8d4d7c808e1af1cbabebca9a2abe603bcefc22294c5b95018d53200cb7" dependencies = [ "log", "recursive", diff --git a/Cargo.toml b/Cargo.toml index 86b6c18..d8786bf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,6 @@ camino = "1.1.9" chrono = "0.4.40" clap = { version = "4.5.29", features = ["derive"], optional = true } sqlformat = "0.3.5" -sqlparser = { version = "0.57.0" } +sqlparser = { version = "0.61.0" } thiserror = "2.0.12" winnow = "0.7.3" diff --git a/src/diff.rs b/src/diff.rs index 7a5a66a..26bbac6 100644 --- a/src/diff.rs +++ b/src/diff.rs @@ -2,8 +2,9 @@ use std::{cmp::Ordering, collections::HashSet, fmt}; use bon::bon; use sqlparser::ast::{ - AlterTableOperation, AlterType, AlterTypeAddValue, AlterTypeAddValuePosition, - AlterTypeOperation, CreateDomain, CreateIndex, CreateTable, DropDomain, Ident, ObjectName, + helpers::attached_token::AttachedToken, AlterTable, AlterTableOperation, AlterType, + AlterTypeAddValue, AlterTypeAddValuePosition, AlterTypeOperation, CreateDomain, + CreateExtension, CreateIndex, CreateTable, DropDomain, DropExtension, Ident, ObjectName, ObjectType, Statement, UserDefinedTypeRepresentation, }; use thiserror::Error; @@ -82,12 +83,12 @@ impl Diff for Vec { Statement::CreateType { name, .. } => { find_and_compare_create_type(sa, name, other) } - Statement::CreateExtension { + Statement::CreateExtension(CreateExtension { name, if_not_exists, cascade, .. - } => { + }) => { find_and_compare_create_extension(sa, name, *if_not_exists, *cascade, other) } Statement::CreateDomain(a) => find_and_compare_create_domain(sa, a, other), @@ -115,9 +116,11 @@ impl Diff for Vec { _ => false, })) } - Statement::CreateExtension { name: b_name, .. } => { + Statement::CreateExtension(CreateExtension { name: b_name, .. }) => { Ok(self.iter().find(|sa| match sa { - Statement::CreateExtension { name: a_name, .. } => a_name == b_name, + Statement::CreateExtension(CreateExtension { + name: a_name, .. + }) => a_name == b_name, _ => false, })) } @@ -264,11 +267,11 @@ fn find_and_compare_create_extension( sa, other, |sb| match sb { - Statement::CreateExtension { name: b_name, .. } => a_name == b_name, + Statement::CreateExtension(CreateExtension { name: b_name, .. }) => a_name == b_name, _ => false, }, || { - Ok(Some(vec![Statement::DropExtension { + Ok(Some(vec![Statement::DropExtension(DropExtension { names: vec![a_name.clone()], if_exists: if_not_exists, cascade_or_restrict: if cascade { @@ -276,7 +279,7 @@ fn find_and_compare_create_extension( } else { None }, - }])) + })])) }, ) } @@ -351,7 +354,7 @@ fn compare_create_table(a: &CreateTable, b: &CreateTable) -> Option Option, b: &Statement, b_name: &ObjectName, - b_rep: &UserDefinedTypeRepresentation, + b_rep: &Option, ) -> Result>, DiffError> { if a_name == b_name && a_rep == b_rep { return Ok(None); } let operations = match a_rep { - UserDefinedTypeRepresentation::Enum { labels: a_labels } => match b_rep { - UserDefinedTypeRepresentation::Enum { labels: b_labels } => { + Some(UserDefinedTypeRepresentation::Enum { labels: a_labels }) => match b_rep { + Some(UserDefinedTypeRepresentation::Enum { labels: b_labels }) => { match a_labels.len().cmp(&b_labels.len()) { Ordering::Equal => { let rename_labels: Vec<_> = a_labels diff --git a/src/lib.rs b/src/lib.rs index b3a40b8..963127c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -429,6 +429,19 @@ mod tests { ); } + #[test] + fn apply_alter_table_drop_columns_snowflake() { + run_test_cases( + vec![TestCase { + dialect: Dialect::Snowflake, + sql_a: "CREATE TABLE bar (foo TEXT, bar TEXT, id INT PRIMARY KEY)", + sql_b: "ALTER TABLE bar DROP COLUMN foo, bar", + expect: "CREATE TABLE bar (id INT PRIMARY KEY);", + }], + |ast_a, ast_b| ast_a.migrate(&ast_b), + ); + } + #[test] fn apply_alter_table_alter_column() { run_test_cases( diff --git a/src/migration.rs b/src/migration.rs index 90a1a82..72c9aa2 100644 --- a/src/migration.rs +++ b/src/migration.rs @@ -2,9 +2,9 @@ use std::fmt; use bon::bon; use sqlparser::ast::{ - AlterColumnOperation, AlterTableOperation, AlterType, AlterTypeAddValuePosition, - AlterTypeOperation, ColumnOption, ColumnOptionDef, CreateTable, GeneratedAs, ObjectName, - ObjectNamePart, ObjectType, Statement, UserDefinedTypeRepresentation, + AlterColumnOperation, AlterTable, AlterTableOperation, AlterType, AlterTypeAddValuePosition, + AlterTypeOperation, ColumnOption, ColumnOptionDef, CreateExtension, CreateTable, DropExtension, + GeneratedAs, ObjectName, ObjectNamePart, ObjectType, Statement, UserDefinedTypeRepresentation, }; use thiserror::Error; @@ -76,7 +76,7 @@ impl Migrate for Vec { Statement::CreateTable(ca) => other .iter() .find(|sb| match sb { - Statement::AlterTable { name, .. } => *name == ca.name, + Statement::AlterTable(AlterTable { name, .. }) => *name == ca.name, Statement::Drop { object_type, names, .. } => { @@ -114,10 +114,12 @@ impl Migrate for Vec { _ => false, }) .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()), - Statement::CreateExtension { name, .. } => other + Statement::CreateExtension(CreateExtension { name, .. }) => other .iter() .find(|sb| match sb { - Statement::DropExtension { names, .. } => names.contains(name), + Statement::DropExtension(DropExtension { names, .. }) => { + names.contains(name) + } _ => false, }) .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()), @@ -152,9 +154,9 @@ impl Migrate for Statement { fn migrate(self, other: &Self) -> Result, MigrateError> { match self { Self::CreateTable(ca) => match other { - Self::AlterTable { + Self::AlterTable(AlterTable { name, operations, .. - } => { + }) => { if *name == ca.name { Ok(Some(Self::CreateTable(migrate_alter_table( ca, operations, @@ -264,8 +266,11 @@ fn migrate_alter_table( AlterTableOperation::AddColumn { column_def, .. } => { t.columns.push(column_def.clone()); } - AlterTableOperation::DropColumn { column_name, .. } => { - t.columns.retain(|c| c.name.value != *column_name.value); + AlterTableOperation::DropColumn { column_names, .. } => { + t.columns.retain(|c| { + !column_names + .iter().any(|name| c.name.value == name.value) + }); } AlterTableOperation::AlterColumn { column_name, op } => { t.columns.iter_mut().for_each(|c| { @@ -297,7 +302,8 @@ fn migrate_alter_table( } AlterColumnOperation::SetDataType { data_type, - using: _, // not applicable since we're not running the query + using: _, // not applicable since we're not running the query + had_set: _, // this doesn't change the meaning } => { c.data_type = data_type.clone(); } @@ -310,8 +316,7 @@ fn migrate_alter_table( c.options.push(ColumnOptionDef { name: None, option: ColumnOption::Generated { - generated_as: generated_as - .clone() + generated_as: (*generated_as) .unwrap_or(GeneratedAs::Always), sequence_options: sequence_options.clone(), generation_expr: None, @@ -339,9 +344,9 @@ fn migrate_alter_table( fn migrate_alter_type( name: ObjectName, - representation: UserDefinedTypeRepresentation, + representation: Option, other: &AlterType, -) -> Result<(ObjectName, UserDefinedTypeRepresentation), MigrateError> { +) -> Result<(ObjectName, Option), MigrateError> { match &other.operation { AlterTypeOperation::Rename(r) => { let mut parts = name.0; @@ -352,7 +357,7 @@ fn migrate_alter_type( Ok((name, representation)) } AlterTypeOperation::AddValue(a) => match representation { - UserDefinedTypeRepresentation::Enum { mut labels } => { + Some(UserDefinedTypeRepresentation::Enum { mut labels }) => { match &a.position { Some(AlterTypeAddValuePosition::Before(before_name)) => { let index = labels @@ -379,9 +384,9 @@ fn migrate_alter_type( None => labels.push(a.value.clone()), } - Ok((name, UserDefinedTypeRepresentation::Enum { labels })) + Ok((name, Some(UserDefinedTypeRepresentation::Enum { labels }))) } - UserDefinedTypeRepresentation::Composite { .. } => Err(MigrateError::builder() + Some(_) | None => Err(MigrateError::builder() .kind(MigrateErrorKind::AlterTypeInvalidOp(Box::new( other.operation.clone(), ))) @@ -393,15 +398,15 @@ fn migrate_alter_type( .build()), }, AlterTypeOperation::RenameValue(rv) => match representation { - UserDefinedTypeRepresentation::Enum { labels } => { + Some(UserDefinedTypeRepresentation::Enum { labels }) => { let labels = labels .into_iter() .map(|l| if l == rv.from { rv.to.clone() } else { l }) .collect::>(); - Ok((name, UserDefinedTypeRepresentation::Enum { labels })) + Ok((name, Some(UserDefinedTypeRepresentation::Enum { labels }))) } - UserDefinedTypeRepresentation::Composite { .. } => Err(MigrateError::builder() + Some(_) | None => Err(MigrateError::builder() .kind(MigrateErrorKind::AlterTypeInvalidOp(Box::new( other.operation.clone(), ))) diff --git a/src/name_gen.rs b/src/name_gen.rs index 911bcdc..0825879 100644 --- a/src/name_gen.rs +++ b/src/name_gen.rs @@ -1,6 +1,6 @@ use sqlparser::ast::{ - AlterTableOperation, AlterType, ColumnDef, CreateIndex, CreateTable, ObjectName, ObjectType, - Statement, + AlterTable, AlterTableOperation, AlterType, ColumnDef, CreateIndex, CreateTable, ObjectName, + ObjectType, RenameTableNameKind, Statement, }; use crate::SyntaxTree; @@ -15,9 +15,9 @@ pub fn generate_name( .iter() .filter_map(|s| match s { Statement::CreateTable(CreateTable { name, .. }) => Some(format!("create_{name}")), - Statement::AlterTable { + Statement::AlterTable(AlterTable { name, operations, .. - } => alter_table_name(name, operations), + }) => alter_table_name(name, operations), Statement::Drop { object_type, names, .. } => { @@ -73,9 +73,14 @@ fn alter_table_name(name: &ObjectName, operations: &[AlterTableOperation]) -> Op column_def: ColumnDef { name, .. }, .. } => Some(format!("add_{name}")), - AlterTableOperation::DropColumn { column_name, .. } => { - Some(format!("drop_{column_name}")) - } + AlterTableOperation::DropColumn { column_names, .. } => Some(format!( + "drop_{}", + column_names + .iter() + .map(|ident| ident.value.clone()) + .collect::>() + .join("_") + )), AlterTableOperation::RenameColumn { old_column_name, new_column_name, @@ -85,7 +90,13 @@ fn alter_table_name(name: &ObjectName, operations: &[AlterTableOperation]) -> Op } AlterTableOperation::RenameTable { table_name } => { table_verb = "rename"; - Some(format!("to_{table_name}")) + Some(format!( + "to_{table_name}", + table_name = match table_name { + RenameTableNameKind::As(name) => name, + RenameTableNameKind::To(name) => name, + } + )) } _ => None, })