diff --git a/README.md b/README.md index bbbfc88..8f39d88 100644 --- a/README.md +++ b/README.md @@ -214,7 +214,7 @@ func (m M0001CreateUserTable) Up(migrator *migrations.Migrator) error { return err } - if db.DBSchema().ExistsTable(table.Name()) { + if db.DBSchema().HasTable(table.Name()) { // if the table already exists, nothing to do here return nil } @@ -239,7 +239,7 @@ func (m M0001CreateUserTable) Up(migrator *migrations.Migrator) error { func (m M0001CreateUserTable) Down(migrator *migrations.Migrator) error { db := migrator.Database() tableName, _ := schema.NewTableName("user") - if !db.DBSchema().ExistsTable(*tableName) { + if !db.DBSchema().HasTable(*tableName) { // if the table already doesn't exist, nothing to do here return nil } @@ -266,7 +266,7 @@ func (m M0002AddUpdatedAtColumn) Up(migrator *migrations.Migrator) error { return err } - if db.DBSchema().ExistsTableColumn(*tableName, col.Name()) { + if db.DBSchema().HasColumn(*tableName, col.Name()) { // if the column already exists, nothing to do here return nil } @@ -287,7 +287,7 @@ func (m M0002AddUpdatedAtColumn) Down(migrator *migrations.Migrator) error { return err } - if !db.DBSchema().ExistsTableColumn(*tableName, *colName) { + if !db.DBSchema().HasColumn(*tableName, *colName) { // if the column already doesn't exist, nothing to do here return nil } @@ -334,11 +334,10 @@ When applying migrations with the `migrations.Migrate()` function, a `MigratorCo To change any of these defaults, use the provided functions to change the values, for example: ```go -config := migrations.NewMigratorConfig( - migrations.WithTableName("custom_migrations"), - migrations.WithColumnNames("migration_name", "applied_on"), - migrations.WithMigrationNameMaxLength(500), -) +config := migrations.NewMigratorConfig(). + WithTableName("custom_migrations"). + WithColumnNames("migration_name", "applied_on"). + WithMigrationNameMaxLength(500) migrator, err := migrations.Migrate( db, diff --git a/internal/testutils/migrations.go b/internal/testutils/migrations.go index 5e78b6f..8382f9d 100644 --- a/internal/testutils/migrations.go +++ b/internal/testutils/migrations.go @@ -15,7 +15,7 @@ func (m M001_Create_Table_User) Up(pMigrator *migrations.Migrator) error { if lError != nil { return lError } - if pMigrator.Database().DBSchema().ExistsTable(lUserTable.Name()) { + if pMigrator.Database().DBSchema().HasTable(lUserTable.Name()) { return nil } columns := []schema.NewColumnParams{ @@ -72,7 +72,7 @@ func (m M001_Create_Table_User) Up(pMigrator *migrations.Migrator) error { } } - _, lError = pMigrator.Database().DB().Exec(pMigrator.Database().SQLBuilder().CreateTable(*lUserTable)) + _, lError = pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().CreateTable(*lUserTable)) if lError != nil { return lError } @@ -84,10 +84,10 @@ func (m M001_Create_Table_User) Down(pMigrator *migrations.Migrator) error { if lError != nil { return lError } - if !pMigrator.Database().DBSchema().ExistsTable(*lUserTableName) { + if !pMigrator.Database().DBSchema().HasTable(*lUserTableName) { return nil } - _, lError = pMigrator.Database().DB().Exec(pMigrator.Database().SQLBuilder().DropTable(*lUserTableName)) + _, lError = pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().DropTable(*lUserTableName)) if lError != nil { return lError } @@ -109,10 +109,10 @@ func (m M002_Alter_Table_User_Add_Column_UpdatedAt) Up(pMigrator *migrations.Mig if lError != nil { return lError } - if pMigrator.Database().DBSchema().ExistsTableColumn(*lUserTableName, lUpdatedAtColumn.Name()) { + if pMigrator.Database().DBSchema().HasColumn(*lUserTableName, lUpdatedAtColumn.Name()) { return nil } - _, lError = pMigrator.Database().DB().Exec(pMigrator.Database().SQLBuilder().AlterTableAddColumn(*lUserTableName, *lUpdatedAtColumn)) + _, lError = pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().AlterTableAddColumn(*lUserTableName, *lUpdatedAtColumn)) if lError != nil { return lError } @@ -128,10 +128,10 @@ func (m M002_Alter_Table_User_Add_Column_UpdatedAt) Down(pMigrator *migrations.M if lError != nil { return lError } - if !pMigrator.Database().DBSchema().ExistsTableColumn(*lUserTableName, *lUpdatedAtColumnName) { + if !pMigrator.Database().DBSchema().HasColumn(*lUserTableName, *lUpdatedAtColumnName) { return nil } - _, lError = pMigrator.Database().DB().Exec(pMigrator.Database().SQLBuilder().AlterTableDropColumn(*lUserTableName, *lUpdatedAtColumnName)) + _, lError = pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().AlterTableDropColumn(*lUserTableName, *lUpdatedAtColumnName)) if lError != nil { return lError } diff --git a/migrations/config.go b/migrations/config.go index 2826ac4..917ca9d 100644 --- a/migrations/config.go +++ b/migrations/config.go @@ -7,36 +7,30 @@ type MigratorConfig struct { appliedAtColumn string } -func NewMigratorConfig(pOptions ...func(*MigratorConfig)) MigratorConfig { +func NewMigratorConfig() MigratorConfig { lConfig := MigratorConfig{ tableName: "__ormshift_migrations", migrationNameColumn: "name", migrationNameMaxLength: 250, appliedAtColumn: "applied_at", } - for _, o := range pOptions { - o(&lConfig) - } return lConfig } -func WithTableName(pTableName string) func(*MigratorConfig) { - return func(mc *MigratorConfig) { - mc.tableName = pTableName - } +func (mc MigratorConfig) WithTableName(pTableName string) MigratorConfig { + mc.tableName = pTableName + return mc } -func WithColumnNames(pMigrationNameColumn, pAppliedAtColumn string) func(*MigratorConfig) { - return func(mc *MigratorConfig) { - mc.migrationNameColumn = pMigrationNameColumn - mc.appliedAtColumn = pAppliedAtColumn - } +func (mc MigratorConfig) WithColumnNames(pMigrationNameColumn, pAppliedAtColumn string) MigratorConfig { + mc.migrationNameColumn = pMigrationNameColumn + mc.appliedAtColumn = pAppliedAtColumn + return mc } -func WithMigrationNameMaxLength(pMaxLength uint) func(*MigratorConfig) { - return func(mc *MigratorConfig) { - mc.migrationNameMaxLength = pMaxLength - } +func (mc MigratorConfig) WithMigrationNameMaxLength(pMaxLength uint) MigratorConfig { + mc.migrationNameMaxLength = pMaxLength + return mc } func (mc MigratorConfig) TableName() string { @@ -45,9 +39,9 @@ func (mc MigratorConfig) TableName() string { func (mc MigratorConfig) MigrationNameColumn() string { return mc.migrationNameColumn } -func (mc MigratorConfig) AppliedAtColumn() string { - return mc.appliedAtColumn -} func (mc MigratorConfig) MigrationNameMaxLength() uint { return mc.migrationNameMaxLength } +func (mc MigratorConfig) AppliedAtColumn() string { + return mc.appliedAtColumn +} diff --git a/migrations/config_test.go b/migrations/config_test.go index db26ae2..156aeb2 100644 --- a/migrations/config_test.go +++ b/migrations/config_test.go @@ -9,6 +9,7 @@ import ( func TestNewMigratorConfigDefaults(t *testing.T) { lConfig := migrations.NewMigratorConfig() + testutils.AssertEqualWithLabel(t, "__ormshift_migrations", lConfig.TableName(), "MigratorConfig.TableName") testutils.AssertEqualWithLabel(t, "name", lConfig.MigrationNameColumn(), "MigratorConfig.MigrationNameColumn") testutils.AssertEqualWithLabel(t, "applied_at", lConfig.AppliedAtColumn(), "MigratorConfig.AppliedAtColumn") @@ -16,11 +17,11 @@ func TestNewMigratorConfigDefaults(t *testing.T) { } func TestNewMigratorConfigCustom(t *testing.T) { - lConfig := migrations.NewMigratorConfig( - migrations.WithTableName("custom_migrations"), - migrations.WithColumnNames("migration_name", "applied_on"), - migrations.WithMigrationNameMaxLength(500), - ) + lConfig := migrations.NewMigratorConfig(). + WithTableName("custom_migrations"). + WithColumnNames("migration_name", "applied_on"). + WithMigrationNameMaxLength(500) + testutils.AssertEqualWithLabel(t, "custom_migrations", lConfig.TableName(), "MigratorConfig.TableName") testutils.AssertEqualWithLabel(t, "migration_name", lConfig.MigrationNameColumn(), "MigratorConfig.MigrationNameColumn") testutils.AssertEqualWithLabel(t, "applied_on", lConfig.AppliedAtColumn(), "MigratorConfig.AppliedAtColumn") diff --git a/migrations/migrations_test.go b/migrations/migrations_test.go index a1e0dc4..f19f974 100644 --- a/migrations/migrations_test.go +++ b/migrations/migrations_test.go @@ -35,7 +35,7 @@ func TestMigrate(t *testing.T) { if !testutils.AssertNilError(t, lError, "migrations.NewColumnName") { return } - testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().ExistsTableColumn(*lUserTableName, *lUpdatedAtColumnName), "Migrator.DBSchema.ExistsTableColumn[user.updated_at]") + testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().HasColumn(*lUserTableName, *lUpdatedAtColumnName), "Migrator.DBSchema.HasColumn[user.updated_at]") testutils.AssertEqualWithLabel(t, 2, len(lMigrator.AppliedMigrations()), "len(Migrator.AppliedMigrationNames)") } @@ -75,7 +75,7 @@ func TestMigrateTwice(t *testing.T) { if !testutils.AssertNilError(t, lError, "migrations.NewColumnName") { return } - testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().ExistsTableColumn(*lUserTableName, *lUpdatedAtColumnName), "Migrator.DBSchema.ExistsTableColumn[user.updated_at]") + testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().HasColumn(*lUserTableName, *lUpdatedAtColumnName), "Migrator.DBSchema.HasColumn[user.updated_at]") testutils.AssertEqualWithLabel(t, 2, len(lMigrator.AppliedMigrations()), "len(Migrator.AppliedMigrations)") } diff --git a/migrations/migrator.go b/migrations/migrator.go index 36e8ed9..0972f7d 100644 --- a/migrations/migrator.go +++ b/migrations/migrator.go @@ -113,7 +113,7 @@ func (m Migrator) recordAppliedMigration(pMigrationName string) error { m.config.appliedAtColumn: time.Now().UTC(), }, ) - _, lError := m.database.DB().Exec(q, p...) + _, lError := m.database.SQLExecutor().Exec(q, p...) return lError } @@ -124,7 +124,7 @@ func (m Migrator) deleteAppliedMigration(pMigrationName string) error { m.config.migrationNameColumn: pMigrationName, }, ) - _, lError := m.database.DB().Exec(q, p...) + _, lError := m.database.SQLExecutor().Exec(q, p...) return lError } @@ -144,7 +144,7 @@ func getAppliedMigrationNames(pDatabase ormshift.Database, pConfig MigratorConfi pConfig.migrationNameColumn, ), ) - lMigrationsRows, lError := pDatabase.DB().Query(q, p...) + lMigrationsRows, lError := pDatabase.SQLExecutor().Query(q, p...) if lError != nil { return nil, lError } @@ -169,7 +169,7 @@ func ensureMigrationsTableExists(pDatabase ormshift.Database, pConfig MigratorCo if lError != nil { return lError } - if !pDatabase.DBSchema().ExistsTable(lMigrationsTable.Name()) { + if !pDatabase.DBSchema().HasTable(lMigrationsTable.Name()) { columns := []schema.NewColumnParams{ { Name: pConfig.migrationNameColumn, @@ -191,7 +191,7 @@ func ensureMigrationsTableExists(pDatabase ormshift.Database, pConfig MigratorCo } } - _, lError = pDatabase.DB().Exec(pDatabase.SQLBuilder().CreateTable(*lMigrationsTable)) // NOSONAR go:S2077 - Dynamic SQL is controlled and sanitized internally + _, lError = pDatabase.SQLExecutor().Exec(pDatabase.SQLBuilder().CreateTable(*lMigrationsTable)) // NOSONAR go:S2077 - Dynamic SQL is controlled and sanitized internally } return lError } diff --git a/migrations/migrator_test.go b/migrations/migrator_test.go index 4663bea..f52d5bc 100644 --- a/migrations/migrator_test.go +++ b/migrations/migrator_test.go @@ -74,7 +74,7 @@ func TestRevertLastAppliedMigration(t *testing.T) { if !testutils.AssertNilError(t, lError, "migrations.NewTableName") { return } - testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().ExistsTable(*lUserTableName), "Migrator.DBSchema.ExistsTable[user]") + testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().HasTable(*lUserTableName), "Migrator.DBSchema.HasTable[user]") lError = lMigrator.RevertLastAppliedMigration() if !testutils.AssertNilError(t, lError, "Migrator.RevertLastAppliedMigration") { @@ -84,7 +84,25 @@ func TestRevertLastAppliedMigration(t *testing.T) { if !testutils.AssertNilError(t, lError, "migrations.NewColumnName") { return } - testutils.AssertEqualWithLabel(t, false, lDB.DBSchema().ExistsTableColumn(*lUserTableName, *lUpdatedAtColumnName), "Migrator.DBSchema.ExistsTableColumn[user.updated_at]") + testutils.AssertEqualWithLabel(t, false, lDB.DBSchema().HasColumn(*lUserTableName, *lUpdatedAtColumnName), "Migrator.DBSchema.HasColumn[user.updated_at]") +} + +func TestRevertLastAppliedMigrationWhenNoMigrationsApplied(t *testing.T) { + lDB, lError := ormshift.OpenDatabase(sqlite.Driver(), ormshift.ConnectionParams{InMemory: true}) + if lError != nil { + t.Errorf("ormshift.OpenDatabase failed: %v", lError) + return + } + defer func() { _ = lDB.Close() }() + + lMigrator, lError := migrations.NewMigrator(lDB, migrations.NewMigratorConfig()) + if !testutils.AssertNotNilResultAndNilError(t, lMigrator, lError, "migrations.NewMigrator") { + return + } + lError = lMigrator.RevertLastAppliedMigration() + if !testutils.AssertNilError(t, lError, "Migrator.RevertLastAppliedMigration") { + return + } } func TestRevertLastAppliedMigrationFailsWhenDownFails(t *testing.T) { diff --git a/schema/schema.go b/schema/schema.go index 2f4067f..2d95994 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "fmt" + "slices" "strings" ) @@ -19,18 +20,14 @@ func NewDBSchema(pDB *sql.DB, pTableNamesQuery string) (*DBSchema, error) { return &DBSchema{db: pDB, tableNamesQuery: pTableNamesQuery}, nil } -func (s DBSchema) ExistsTable(pTableName TableName) bool { +func (s DBSchema) HasTable(pTableName TableName) bool { lTables, lError := s.fetchTableNames() if lError != nil { return false } - for _, lTable := range lTables { - lUpperTableName := strings.ToUpper(lTable) - if lUpperTableName == strings.ToUpper(pTableName.String()) { - return true - } - } - return false + return slices.ContainsFunc(lTables, func(t string) bool { + return strings.EqualFold(t, pTableName.String()) + }) } func (s DBSchema) fetchTableNames() ([]string, error) { @@ -55,31 +52,18 @@ func (s DBSchema) fetchTableNames() ([]string, error) { return lTableNames, lError } -func (s DBSchema) CheckTableColumnType(pTableName TableName, pColumnName ColumnName) (*sql.ColumnType, error) { +func (s DBSchema) HasColumn(pTableName TableName, pColumnName ColumnName) bool { lColumnTypes, lError := s.fetchColumnTypes(pTableName) if lError != nil { - return nil, lError - } - for _, lColumnType := range lColumnTypes { - if lColumnType.Name() == pColumnName.String() { - return lColumnType, nil - } + return false } - return nil, fmt.Errorf("column %q not found in table %q", pColumnName.String(), pTableName.String()) -} - -func (s DBSchema) ExistsTableColumn(pTableName TableName, pColumnName ColumnName) bool { - _, lError := s.CheckTableColumnType(pTableName, pColumnName) - return lError == nil + return slices.ContainsFunc(lColumnTypes, func(ct *sql.ColumnType) bool { + return strings.EqualFold(ct.Name(), pColumnName.String()) + }) } func (s DBSchema) fetchColumnTypes(pTableName TableName) ([]*sql.ColumnType, error) { - lTableName := pTableName.String() - if !regexValidTableName.MatchString(lTableName) { - return nil, fmt.Errorf("invalid table name: %q", lTableName) - } - - lRows, lError := s.db.Query(fmt.Sprintf("SELECT * FROM %s WHERE 1=0", lTableName)) // NOSONAR go:S2077 - Dynamic SQL is controlled and sanitized internally + lRows, lError := s.db.Query(fmt.Sprintf("SELECT * FROM %s WHERE 1=0", pTableName.String())) // NOSONAR go:S2077 - Dynamic SQL is controlled and sanitized internally if lError != nil { return nil, lError } diff --git a/schema/schema_test.go b/schema/schema_test.go index b8ddb7c..2a30e66 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -30,7 +30,7 @@ func TestNewDBSchemaFailsWhenDBIsNil(t *testing.T) { testutils.AssertErrorMessage(t, "sql.DB cannot be nil", lError, "schema.NewDBSchema") } -func TestExistsTableColumn(t *testing.T) { +func TestHasColumn(t *testing.T) { lDB, lError := ormshift.OpenDatabase(sqlite.Driver(), ormshift.ConnectionParams{InMemory: true}) if lError != nil { t.Errorf("ormshift.OpenDatabase failed: %v", lError) @@ -49,24 +49,24 @@ func TestExistsTableColumn(t *testing.T) { } lDBSchema := lDB.DBSchema() - testutils.AssertEqualWithLabel(t, true, lDBSchema.ExistsTable(lProductAttributeTable.Name()), "DBSchema.ExistsTable") + testutils.AssertEqualWithLabel(t, true, lDBSchema.HasTable(lProductAttributeTable.Name()), "DBSchema.HasTable") for _, lColumn := range lProductAttributeTable.Columns() { - testutils.AssertEqualWithLabel(t, true, lDBSchema.ExistsTableColumn(lProductAttributeTable.Name(), lColumn.Name()), "DBSchema.ExistsTableColumn") + testutils.AssertEqualWithLabel(t, true, lDBSchema.HasColumn(lProductAttributeTable.Name(), lColumn.Name()), "DBSchema.HasColumn") } lAnyTableName, lError := schema.NewTableName("any_table") if !testutils.AssertNotNilResultAndNilError(t, lAnyTableName, lError, "ormshift.NewTableName") { return } - testutils.AssertEqualWithLabel(t, false, lDBSchema.ExistsTable(*lAnyTableName), "DBSchema.ExistsTable") + testutils.AssertEqualWithLabel(t, false, lDBSchema.HasTable(*lAnyTableName), "DBSchema.HasTable") lAnyColumnName, lError := schema.NewColumnName("any_col") if !testutils.AssertNotNilResultAndNilError(t, lAnyColumnName, lError, "ormshift.NewTableName") { return } - testutils.AssertEqualWithLabel(t, false, lDBSchema.ExistsTableColumn(lProductAttributeTable.Name(), *lAnyColumnName), "DBSchema.ExistsTableColumn") - testutils.AssertEqualWithLabel(t, false, lDBSchema.ExistsTableColumn(*lAnyTableName, *lAnyColumnName), "DBSchema.ExistsTableColumn") + testutils.AssertEqualWithLabel(t, false, lDBSchema.HasColumn(lProductAttributeTable.Name(), *lAnyColumnName), "DBSchema.HasColumn") + testutils.AssertEqualWithLabel(t, false, lDBSchema.HasColumn(*lAnyTableName, *lAnyColumnName), "DBSchema.HasColumn") } -func TestExistsTableReturnsFalseWhenDatabaseIsInvalid(t *testing.T) { +func TestHasTableReturnsFalseWhenDatabaseIsInvalid(t *testing.T) { lDB, lError := ormshift.OpenDatabase(sqlite.Driver(), ormshift.ConnectionParams{InMemory: true}) if lError != nil { t.Errorf("ormshift.OpenDatabase failed: %v", lError) @@ -87,5 +87,5 @@ func TestExistsTableReturnsFalseWhenDatabaseIsInvalid(t *testing.T) { } _ = lDB.Close() lDBSchema := lDB.DBSchema() - testutils.AssertEqualWithLabel(t, false, lDBSchema.ExistsTable(lProductAttributeTable.Name()), "DBSchema.ExistsTable") + testutils.AssertEqualWithLabel(t, false, lDBSchema.HasTable(lProductAttributeTable.Name()), "DBSchema.HasTable") } diff --git a/schema/table.go b/schema/table.go index 9123432..3f4eca6 100644 --- a/schema/table.go +++ b/schema/table.go @@ -45,9 +45,8 @@ func (t *Table) AddColumn(pParams NewColumnParams) error { if lError != nil { return lError } - lLowerColumnName := strings.ToLower(lColumn.Name().String()) lColumnAlreadyExists := slices.ContainsFunc(t.columns, func(c Column) bool { - return lLowerColumnName == strings.ToLower(c.Name().String()) + return strings.EqualFold(lColumn.Name().String(), c.Name().String()) }) if lColumnAlreadyExists { return fmt.Errorf("column %q already exists in table %q", lColumn.Name().String(), t.name)