Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ camino = "1.1.9"
chrono = "0.4.40"
clap = { version = "4.5.29", features = ["derive"], optional = true }
sqlformat = "0.3.5"
sqlparser = { version = "0.57.0" }
sqlparser = { version = "0.61.0" }
thiserror = "2.0.12"
winnow = "0.7.3"
38 changes: 21 additions & 17 deletions src/diff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use std::{cmp::Ordering, collections::HashSet, fmt};

use bon::bon;
use sqlparser::ast::{
AlterTableOperation, AlterType, AlterTypeAddValue, AlterTypeAddValuePosition,
AlterTypeOperation, CreateDomain, CreateIndex, CreateTable, DropDomain, Ident, ObjectName,
helpers::attached_token::AttachedToken, AlterTable, AlterTableOperation, AlterType,
AlterTypeAddValue, AlterTypeAddValuePosition, AlterTypeOperation, CreateDomain,
CreateExtension, CreateIndex, CreateTable, DropDomain, DropExtension, Ident, ObjectName,
ObjectType, Statement, UserDefinedTypeRepresentation,
};
use thiserror::Error;
Expand Down Expand Up @@ -82,12 +83,12 @@ impl Diff for Vec<Statement> {
Statement::CreateType { name, .. } => {
find_and_compare_create_type(sa, name, other)
}
Statement::CreateExtension {
Statement::CreateExtension(CreateExtension {
name,
if_not_exists,
cascade,
..
} => {
}) => {
find_and_compare_create_extension(sa, name, *if_not_exists, *cascade, other)
}
Statement::CreateDomain(a) => find_and_compare_create_domain(sa, a, other),
Expand Down Expand Up @@ -115,9 +116,11 @@ impl Diff for Vec<Statement> {
_ => false,
}))
}
Statement::CreateExtension { name: b_name, .. } => {
Statement::CreateExtension(CreateExtension { name: b_name, .. }) => {
Ok(self.iter().find(|sa| match sa {
Statement::CreateExtension { name: a_name, .. } => a_name == b_name,
Statement::CreateExtension(CreateExtension {
name: a_name, ..
}) => a_name == b_name,
_ => false,
}))
}
Expand Down Expand Up @@ -264,19 +267,19 @@ fn find_and_compare_create_extension(
sa,
other,
|sb| match sb {
Statement::CreateExtension { name: b_name, .. } => a_name == b_name,
Statement::CreateExtension(CreateExtension { name: b_name, .. }) => a_name == b_name,
_ => false,
},
|| {
Ok(Some(vec![Statement::DropExtension {
Ok(Some(vec![Statement::DropExtension(DropExtension {
names: vec![a_name.clone()],
if_exists: if_not_exists,
cascade_or_restrict: if cascade {
Some(sqlparser::ast::ReferentialAction::Cascade)
} else {
None
},
}]))
})]))
},
)
}
Expand Down Expand Up @@ -351,7 +354,7 @@ fn compare_create_table(a: &CreateTable, b: &CreateTable) -> Option<Vec<Statemen
} else {
// drop column if it only exists in `a`
Some(AlterTableOperation::DropColumn {
column_name: ac.name.clone(),
column_names: vec![ac.name.clone()],
if_exists: a.if_not_exists,
drop_behavior: None,
has_column_keyword: true,
Expand All @@ -377,15 +380,16 @@ fn compare_create_table(a: &CreateTable, b: &CreateTable) -> Option<Vec<Statemen
return None;
}

Some(vec![Statement::AlterTable {
Some(vec![Statement::AlterTable(AlterTable {
table_type: None,
name: a.name.clone(),
if_exists: a.if_not_exists,
only: false,
operations,
location: None,
on_cluster: a.on_cluster.clone(),
iceberg: false,
}])
end_token: AttachedToken::empty(),
})])
}

fn compare_create_index(
Expand Down Expand Up @@ -423,18 +427,18 @@ fn compare_create_index(
fn compare_create_type(
a: &Statement,
a_name: &ObjectName,
a_rep: &UserDefinedTypeRepresentation,
a_rep: &Option<UserDefinedTypeRepresentation>,
b: &Statement,
b_name: &ObjectName,
b_rep: &UserDefinedTypeRepresentation,
b_rep: &Option<UserDefinedTypeRepresentation>,
) -> Result<Option<Vec<Statement>>, DiffError> {
if a_name == b_name && a_rep == b_rep {
return Ok(None);
}

let operations = match a_rep {
UserDefinedTypeRepresentation::Enum { labels: a_labels } => match b_rep {
UserDefinedTypeRepresentation::Enum { labels: b_labels } => {
Some(UserDefinedTypeRepresentation::Enum { labels: a_labels }) => match b_rep {
Some(UserDefinedTypeRepresentation::Enum { labels: b_labels }) => {
match a_labels.len().cmp(&b_labels.len()) {
Ordering::Equal => {
let rename_labels: Vec<_> = a_labels
Expand Down
13 changes: 13 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,19 @@ mod tests {
);
}

#[test]
fn apply_alter_table_drop_columns_snowflake() {
run_test_cases(
vec![TestCase {
dialect: Dialect::Snowflake,
sql_a: "CREATE TABLE bar (foo TEXT, bar TEXT, id INT PRIMARY KEY)",
sql_b: "ALTER TABLE bar DROP COLUMN foo, bar",
expect: "CREATE TABLE bar (id INT PRIMARY KEY);",
}],
|ast_a, ast_b| ast_a.migrate(&ast_b),
);
}

#[test]
fn apply_alter_table_alter_column() {
run_test_cases(
Expand Down
47 changes: 26 additions & 21 deletions src/migration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use std::fmt;

use bon::bon;
use sqlparser::ast::{
AlterColumnOperation, AlterTableOperation, AlterType, AlterTypeAddValuePosition,
AlterTypeOperation, ColumnOption, ColumnOptionDef, CreateTable, GeneratedAs, ObjectName,
ObjectNamePart, ObjectType, Statement, UserDefinedTypeRepresentation,
AlterColumnOperation, AlterTable, AlterTableOperation, AlterType, AlterTypeAddValuePosition,
AlterTypeOperation, ColumnOption, ColumnOptionDef, CreateExtension, CreateTable, DropExtension,
GeneratedAs, ObjectName, ObjectNamePart, ObjectType, Statement, UserDefinedTypeRepresentation,
};
use thiserror::Error;

Expand Down Expand Up @@ -76,7 +76,7 @@ impl Migrate for Vec<Statement> {
Statement::CreateTable(ca) => other
.iter()
.find(|sb| match sb {
Statement::AlterTable { name, .. } => *name == ca.name,
Statement::AlterTable(AlterTable { name, .. }) => *name == ca.name,
Statement::Drop {
object_type, names, ..
} => {
Expand Down Expand Up @@ -114,10 +114,12 @@ impl Migrate for Vec<Statement> {
_ => false,
})
.map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
Statement::CreateExtension { name, .. } => other
Statement::CreateExtension(CreateExtension { name, .. }) => other
.iter()
.find(|sb| match sb {
Statement::DropExtension { names, .. } => names.contains(name),
Statement::DropExtension(DropExtension { names, .. }) => {
names.contains(name)
}
_ => false,
})
.map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
Expand Down Expand Up @@ -152,9 +154,9 @@ impl Migrate for Statement {
fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError> {
match self {
Self::CreateTable(ca) => match other {
Self::AlterTable {
Self::AlterTable(AlterTable {
name, operations, ..
} => {
}) => {
if *name == ca.name {
Ok(Some(Self::CreateTable(migrate_alter_table(
ca, operations,
Expand Down Expand Up @@ -264,8 +266,11 @@ fn migrate_alter_table(
AlterTableOperation::AddColumn { column_def, .. } => {
t.columns.push(column_def.clone());
}
AlterTableOperation::DropColumn { column_name, .. } => {
t.columns.retain(|c| c.name.value != *column_name.value);
AlterTableOperation::DropColumn { column_names, .. } => {
t.columns.retain(|c| {
!column_names
.iter().any(|name| c.name.value == name.value)
});
}
AlterTableOperation::AlterColumn { column_name, op } => {
t.columns.iter_mut().for_each(|c| {
Expand Down Expand Up @@ -297,7 +302,8 @@ fn migrate_alter_table(
}
AlterColumnOperation::SetDataType {
data_type,
using: _, // not applicable since we're not running the query
using: _, // not applicable since we're not running the query
had_set: _, // this doesn't change the meaning
} => {
c.data_type = data_type.clone();
}
Expand All @@ -310,8 +316,7 @@ fn migrate_alter_table(
c.options.push(ColumnOptionDef {
name: None,
option: ColumnOption::Generated {
generated_as: generated_as
.clone()
generated_as: (*generated_as)
.unwrap_or(GeneratedAs::Always),
sequence_options: sequence_options.clone(),
generation_expr: None,
Expand Down Expand Up @@ -339,9 +344,9 @@ fn migrate_alter_table(

fn migrate_alter_type(
name: ObjectName,
representation: UserDefinedTypeRepresentation,
representation: Option<UserDefinedTypeRepresentation>,
other: &AlterType,
) -> Result<(ObjectName, UserDefinedTypeRepresentation), MigrateError> {
) -> Result<(ObjectName, Option<UserDefinedTypeRepresentation>), MigrateError> {
match &other.operation {
AlterTypeOperation::Rename(r) => {
let mut parts = name.0;
Expand All @@ -352,7 +357,7 @@ fn migrate_alter_type(
Ok((name, representation))
}
AlterTypeOperation::AddValue(a) => match representation {
UserDefinedTypeRepresentation::Enum { mut labels } => {
Some(UserDefinedTypeRepresentation::Enum { mut labels }) => {
match &a.position {
Some(AlterTypeAddValuePosition::Before(before_name)) => {
let index = labels
Expand All @@ -379,9 +384,9 @@ fn migrate_alter_type(
None => labels.push(a.value.clone()),
}

Ok((name, UserDefinedTypeRepresentation::Enum { labels }))
Ok((name, Some(UserDefinedTypeRepresentation::Enum { labels })))
}
UserDefinedTypeRepresentation::Composite { .. } => Err(MigrateError::builder()
Some(_) | None => Err(MigrateError::builder()
.kind(MigrateErrorKind::AlterTypeInvalidOp(Box::new(
other.operation.clone(),
)))
Expand All @@ -393,15 +398,15 @@ fn migrate_alter_type(
.build()),
},
AlterTypeOperation::RenameValue(rv) => match representation {
UserDefinedTypeRepresentation::Enum { labels } => {
Some(UserDefinedTypeRepresentation::Enum { labels }) => {
let labels = labels
.into_iter()
.map(|l| if l == rv.from { rv.to.clone() } else { l })
.collect::<Vec<_>>();

Ok((name, UserDefinedTypeRepresentation::Enum { labels }))
Ok((name, Some(UserDefinedTypeRepresentation::Enum { labels })))
}
UserDefinedTypeRepresentation::Composite { .. } => Err(MigrateError::builder()
Some(_) | None => Err(MigrateError::builder()
.kind(MigrateErrorKind::AlterTypeInvalidOp(Box::new(
other.operation.clone(),
)))
Expand Down
27 changes: 19 additions & 8 deletions src/name_gen.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use sqlparser::ast::{
AlterTableOperation, AlterType, ColumnDef, CreateIndex, CreateTable, ObjectName, ObjectType,
Statement,
AlterTable, AlterTableOperation, AlterType, ColumnDef, CreateIndex, CreateTable, ObjectName,
ObjectType, RenameTableNameKind, Statement,
};

use crate::SyntaxTree;
Expand All @@ -15,9 +15,9 @@ pub fn generate_name(
.iter()
.filter_map(|s| match s {
Statement::CreateTable(CreateTable { name, .. }) => Some(format!("create_{name}")),
Statement::AlterTable {
Statement::AlterTable(AlterTable {
name, operations, ..
} => alter_table_name(name, operations),
}) => alter_table_name(name, operations),
Statement::Drop {
object_type, names, ..
} => {
Expand Down Expand Up @@ -73,9 +73,14 @@ fn alter_table_name(name: &ObjectName, operations: &[AlterTableOperation]) -> Op
column_def: ColumnDef { name, .. },
..
} => Some(format!("add_{name}")),
AlterTableOperation::DropColumn { column_name, .. } => {
Some(format!("drop_{column_name}"))
}
AlterTableOperation::DropColumn { column_names, .. } => Some(format!(
"drop_{}",
column_names
.iter()
.map(|ident| ident.value.clone())
.collect::<Vec<_>>()
.join("_")
)),
AlterTableOperation::RenameColumn {
old_column_name,
new_column_name,
Expand All @@ -85,7 +90,13 @@ fn alter_table_name(name: &ObjectName, operations: &[AlterTableOperation]) -> Op
}
AlterTableOperation::RenameTable { table_name } => {
table_verb = "rename";
Some(format!("to_{table_name}"))
Some(format!(
"to_{table_name}",
table_name = match table_name {
RenameTableNameKind::As(name) => name,
RenameTableNameKind::To(name) => name,
}
))
}
_ => None,
})
Expand Down
Loading