Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func (m M0001CreateUserTable) Up(migrator *migrations.Migrator) error {
return err
}

if db.DBSchema().ExistsTable(table.Name()) {
if db.DBSchema().HasTable(table.Name()) {
// if the table already exists, nothing to do here
return nil
}
Expand All @@ -239,7 +239,7 @@ func (m M0001CreateUserTable) Up(migrator *migrations.Migrator) error {
func (m M0001CreateUserTable) Down(migrator *migrations.Migrator) error {
db := migrator.Database()
tableName, _ := schema.NewTableName("user")
if !db.DBSchema().ExistsTable(*tableName) {
if !db.DBSchema().HasTable(*tableName) {
// if the table already doesn't exist, nothing to do here
return nil
}
Expand All @@ -266,7 +266,7 @@ func (m M0002AddUpdatedAtColumn) Up(migrator *migrations.Migrator) error {
return err
}

if db.DBSchema().ExistsTableColumn(*tableName, col.Name()) {
if db.DBSchema().HasColumn(*tableName, col.Name()) {
// if the column already exists, nothing to do here
return nil
}
Expand All @@ -287,7 +287,7 @@ func (m M0002AddUpdatedAtColumn) Down(migrator *migrations.Migrator) error {
return err
}

if !db.DBSchema().ExistsTableColumn(*tableName, *colName) {
if !db.DBSchema().HasColumn(*tableName, *colName) {
// if the column already doesn't exist, nothing to do here
return nil
}
Expand Down Expand Up @@ -334,11 +334,10 @@ When applying migrations with the `migrations.Migrate()` function, a `MigratorCo
To change any of these defaults, use the provided functions to change the values, for example:

```go
config := migrations.NewMigratorConfig(
migrations.WithTableName("custom_migrations"),
migrations.WithColumnNames("migration_name", "applied_on"),
migrations.WithMigrationNameMaxLength(500),
)
config := migrations.NewMigratorConfig().
WithTableName("custom_migrations").
WithColumnNames("migration_name", "applied_on").
WithMigrationNameMaxLength(500)

migrator, err := migrations.Migrate(
db,
Expand Down
16 changes: 8 additions & 8 deletions internal/testutils/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func (m M001_Create_Table_User) Up(pMigrator *migrations.Migrator) error {
if lError != nil {
return lError
}
if pMigrator.Database().DBSchema().ExistsTable(lUserTable.Name()) {
if pMigrator.Database().DBSchema().HasTable(lUserTable.Name()) {
return nil
}
columns := []schema.NewColumnParams{
Expand Down Expand Up @@ -72,7 +72,7 @@ func (m M001_Create_Table_User) Up(pMigrator *migrations.Migrator) error {
}
}

_, lError = pMigrator.Database().DB().Exec(pMigrator.Database().SQLBuilder().CreateTable(*lUserTable))
_, lError = pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().CreateTable(*lUserTable))
if lError != nil {
return lError
}
Expand All @@ -84,10 +84,10 @@ func (m M001_Create_Table_User) Down(pMigrator *migrations.Migrator) error {
if lError != nil {
return lError
}
if !pMigrator.Database().DBSchema().ExistsTable(*lUserTableName) {
if !pMigrator.Database().DBSchema().HasTable(*lUserTableName) {
return nil
}
_, lError = pMigrator.Database().DB().Exec(pMigrator.Database().SQLBuilder().DropTable(*lUserTableName))
_, lError = pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().DropTable(*lUserTableName))
if lError != nil {
return lError
}
Expand All @@ -109,10 +109,10 @@ func (m M002_Alter_Table_User_Add_Column_UpdatedAt) Up(pMigrator *migrations.Mig
if lError != nil {
return lError
}
if pMigrator.Database().DBSchema().ExistsTableColumn(*lUserTableName, lUpdatedAtColumn.Name()) {
if pMigrator.Database().DBSchema().HasColumn(*lUserTableName, lUpdatedAtColumn.Name()) {
return nil
}
_, lError = pMigrator.Database().DB().Exec(pMigrator.Database().SQLBuilder().AlterTableAddColumn(*lUserTableName, *lUpdatedAtColumn))
_, lError = pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().AlterTableAddColumn(*lUserTableName, *lUpdatedAtColumn))
if lError != nil {
return lError
}
Expand All @@ -128,10 +128,10 @@ func (m M002_Alter_Table_User_Add_Column_UpdatedAt) Down(pMigrator *migrations.M
if lError != nil {
return lError
}
if !pMigrator.Database().DBSchema().ExistsTableColumn(*lUserTableName, *lUpdatedAtColumnName) {
if !pMigrator.Database().DBSchema().HasColumn(*lUserTableName, *lUpdatedAtColumnName) {
return nil
}
_, lError = pMigrator.Database().DB().Exec(pMigrator.Database().SQLBuilder().AlterTableDropColumn(*lUserTableName, *lUpdatedAtColumnName))
_, lError = pMigrator.Database().SQLExecutor().Exec(pMigrator.Database().SQLBuilder().AlterTableDropColumn(*lUserTableName, *lUpdatedAtColumnName))
if lError != nil {
return lError
}
Expand Down
34 changes: 14 additions & 20 deletions migrations/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,30 @@ type MigratorConfig struct {
appliedAtColumn string
}

func NewMigratorConfig(pOptions ...func(*MigratorConfig)) MigratorConfig {
func NewMigratorConfig() MigratorConfig {
lConfig := MigratorConfig{
tableName: "__ormshift_migrations",
migrationNameColumn: "name",
migrationNameMaxLength: 250,
appliedAtColumn: "applied_at",
}
for _, o := range pOptions {
o(&lConfig)
}
return lConfig
}

func WithTableName(pTableName string) func(*MigratorConfig) {
return func(mc *MigratorConfig) {
mc.tableName = pTableName
}
func (mc MigratorConfig) WithTableName(pTableName string) MigratorConfig {
mc.tableName = pTableName
return mc
}

func WithColumnNames(pMigrationNameColumn, pAppliedAtColumn string) func(*MigratorConfig) {
return func(mc *MigratorConfig) {
mc.migrationNameColumn = pMigrationNameColumn
mc.appliedAtColumn = pAppliedAtColumn
}
func (mc MigratorConfig) WithColumnNames(pMigrationNameColumn, pAppliedAtColumn string) MigratorConfig {
mc.migrationNameColumn = pMigrationNameColumn
mc.appliedAtColumn = pAppliedAtColumn
return mc
}

func WithMigrationNameMaxLength(pMaxLength uint) func(*MigratorConfig) {
return func(mc *MigratorConfig) {
mc.migrationNameMaxLength = pMaxLength
}
func (mc MigratorConfig) WithMigrationNameMaxLength(pMaxLength uint) MigratorConfig {
mc.migrationNameMaxLength = pMaxLength
return mc
}

func (mc MigratorConfig) TableName() string {
Expand All @@ -45,9 +39,9 @@ func (mc MigratorConfig) TableName() string {
func (mc MigratorConfig) MigrationNameColumn() string {
return mc.migrationNameColumn
}
func (mc MigratorConfig) AppliedAtColumn() string {
return mc.appliedAtColumn
}
func (mc MigratorConfig) MigrationNameMaxLength() uint {
return mc.migrationNameMaxLength
}
func (mc MigratorConfig) AppliedAtColumn() string {
return mc.appliedAtColumn
}
11 changes: 6 additions & 5 deletions migrations/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,19 @@ import (

func TestNewMigratorConfigDefaults(t *testing.T) {
lConfig := migrations.NewMigratorConfig()

testutils.AssertEqualWithLabel(t, "__ormshift_migrations", lConfig.TableName(), "MigratorConfig.TableName")
testutils.AssertEqualWithLabel(t, "name", lConfig.MigrationNameColumn(), "MigratorConfig.MigrationNameColumn")
testutils.AssertEqualWithLabel(t, "applied_at", lConfig.AppliedAtColumn(), "MigratorConfig.AppliedAtColumn")
testutils.AssertEqualWithLabel(t, uint(250), lConfig.MigrationNameMaxLength(), "MigratorConfig.MigrationNameMaxLength")
}

func TestNewMigratorConfigCustom(t *testing.T) {
lConfig := migrations.NewMigratorConfig(
migrations.WithTableName("custom_migrations"),
migrations.WithColumnNames("migration_name", "applied_on"),
migrations.WithMigrationNameMaxLength(500),
)
lConfig := migrations.NewMigratorConfig().
WithTableName("custom_migrations").
WithColumnNames("migration_name", "applied_on").
WithMigrationNameMaxLength(500)

testutils.AssertEqualWithLabel(t, "custom_migrations", lConfig.TableName(), "MigratorConfig.TableName")
testutils.AssertEqualWithLabel(t, "migration_name", lConfig.MigrationNameColumn(), "MigratorConfig.MigrationNameColumn")
testutils.AssertEqualWithLabel(t, "applied_on", lConfig.AppliedAtColumn(), "MigratorConfig.AppliedAtColumn")
Expand Down
4 changes: 2 additions & 2 deletions migrations/migrations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestMigrate(t *testing.T) {
if !testutils.AssertNilError(t, lError, "migrations.NewColumnName") {
return
}
testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().ExistsTableColumn(*lUserTableName, *lUpdatedAtColumnName), "Migrator.DBSchema.ExistsTableColumn[user.updated_at]")
testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().HasColumn(*lUserTableName, *lUpdatedAtColumnName), "Migrator.DBSchema.HasColumn[user.updated_at]")
testutils.AssertEqualWithLabel(t, 2, len(lMigrator.AppliedMigrations()), "len(Migrator.AppliedMigrationNames)")
}

Expand Down Expand Up @@ -75,7 +75,7 @@ func TestMigrateTwice(t *testing.T) {
if !testutils.AssertNilError(t, lError, "migrations.NewColumnName") {
return
}
testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().ExistsTableColumn(*lUserTableName, *lUpdatedAtColumnName), "Migrator.DBSchema.ExistsTableColumn[user.updated_at]")
testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().HasColumn(*lUserTableName, *lUpdatedAtColumnName), "Migrator.DBSchema.HasColumn[user.updated_at]")
testutils.AssertEqualWithLabel(t, 2, len(lMigrator.AppliedMigrations()), "len(Migrator.AppliedMigrations)")
}

Expand Down
10 changes: 5 additions & 5 deletions migrations/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (m Migrator) recordAppliedMigration(pMigrationName string) error {
m.config.appliedAtColumn: time.Now().UTC(),
},
)
_, lError := m.database.DB().Exec(q, p...)
_, lError := m.database.SQLExecutor().Exec(q, p...)
return lError
}

Expand All @@ -124,7 +124,7 @@ func (m Migrator) deleteAppliedMigration(pMigrationName string) error {
m.config.migrationNameColumn: pMigrationName,
},
)
_, lError := m.database.DB().Exec(q, p...)
_, lError := m.database.SQLExecutor().Exec(q, p...)
return lError
}

Expand All @@ -144,7 +144,7 @@ func getAppliedMigrationNames(pDatabase ormshift.Database, pConfig MigratorConfi
pConfig.migrationNameColumn,
),
)
lMigrationsRows, lError := pDatabase.DB().Query(q, p...)
lMigrationsRows, lError := pDatabase.SQLExecutor().Query(q, p...)
if lError != nil {
return nil, lError
}
Expand All @@ -169,7 +169,7 @@ func ensureMigrationsTableExists(pDatabase ormshift.Database, pConfig MigratorCo
if lError != nil {
return lError
}
if !pDatabase.DBSchema().ExistsTable(lMigrationsTable.Name()) {
if !pDatabase.DBSchema().HasTable(lMigrationsTable.Name()) {
columns := []schema.NewColumnParams{
{
Name: pConfig.migrationNameColumn,
Expand All @@ -191,7 +191,7 @@ func ensureMigrationsTableExists(pDatabase ormshift.Database, pConfig MigratorCo
}
}

_, lError = pDatabase.DB().Exec(pDatabase.SQLBuilder().CreateTable(*lMigrationsTable)) // NOSONAR go:S2077 - Dynamic SQL is controlled and sanitized internally
_, lError = pDatabase.SQLExecutor().Exec(pDatabase.SQLBuilder().CreateTable(*lMigrationsTable)) // NOSONAR go:S2077 - Dynamic SQL is controlled and sanitized internally
}
return lError
}
22 changes: 20 additions & 2 deletions migrations/migrator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func TestRevertLastAppliedMigration(t *testing.T) {
if !testutils.AssertNilError(t, lError, "migrations.NewTableName") {
return
}
testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().ExistsTable(*lUserTableName), "Migrator.DBSchema.ExistsTable[user]")
testutils.AssertEqualWithLabel(t, true, lDB.DBSchema().HasTable(*lUserTableName), "Migrator.DBSchema.HasTable[user]")

lError = lMigrator.RevertLastAppliedMigration()
if !testutils.AssertNilError(t, lError, "Migrator.RevertLastAppliedMigration") {
Expand All @@ -84,7 +84,25 @@ func TestRevertLastAppliedMigration(t *testing.T) {
if !testutils.AssertNilError(t, lError, "migrations.NewColumnName") {
return
}
testutils.AssertEqualWithLabel(t, false, lDB.DBSchema().ExistsTableColumn(*lUserTableName, *lUpdatedAtColumnName), "Migrator.DBSchema.ExistsTableColumn[user.updated_at]")
testutils.AssertEqualWithLabel(t, false, lDB.DBSchema().HasColumn(*lUserTableName, *lUpdatedAtColumnName), "Migrator.DBSchema.HasColumn[user.updated_at]")
}

func TestRevertLastAppliedMigrationWhenNoMigrationsApplied(t *testing.T) {
lDB, lError := ormshift.OpenDatabase(sqlite.Driver(), ormshift.ConnectionParams{InMemory: true})
if lError != nil {
t.Errorf("ormshift.OpenDatabase failed: %v", lError)
return
}
defer func() { _ = lDB.Close() }()

lMigrator, lError := migrations.NewMigrator(lDB, migrations.NewMigratorConfig())
if !testutils.AssertNotNilResultAndNilError(t, lMigrator, lError, "migrations.NewMigrator") {
return
}
lError = lMigrator.RevertLastAppliedMigration()
if !testutils.AssertNilError(t, lError, "Migrator.RevertLastAppliedMigration") {
return
}
}

func TestRevertLastAppliedMigrationFailsWhenDownFails(t *testing.T) {
Expand Down
38 changes: 11 additions & 27 deletions schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"errors"
"fmt"
"slices"
"strings"
)

Expand All @@ -19,18 +20,14 @@ func NewDBSchema(pDB *sql.DB, pTableNamesQuery string) (*DBSchema, error) {
return &DBSchema{db: pDB, tableNamesQuery: pTableNamesQuery}, nil
}

func (s DBSchema) ExistsTable(pTableName TableName) bool {
func (s DBSchema) HasTable(pTableName TableName) bool {
lTables, lError := s.fetchTableNames()
if lError != nil {
return false
}
for _, lTable := range lTables {
lUpperTableName := strings.ToUpper(lTable)
if lUpperTableName == strings.ToUpper(pTableName.String()) {
return true
}
}
return false
return slices.ContainsFunc(lTables, func(t string) bool {
return strings.EqualFold(t, pTableName.String())
})
}

func (s DBSchema) fetchTableNames() ([]string, error) {
Expand All @@ -55,31 +52,18 @@ func (s DBSchema) fetchTableNames() ([]string, error) {
return lTableNames, lError
}

func (s DBSchema) CheckTableColumnType(pTableName TableName, pColumnName ColumnName) (*sql.ColumnType, error) {
func (s DBSchema) HasColumn(pTableName TableName, pColumnName ColumnName) bool {
lColumnTypes, lError := s.fetchColumnTypes(pTableName)
if lError != nil {
return nil, lError
}
for _, lColumnType := range lColumnTypes {
if lColumnType.Name() == pColumnName.String() {
return lColumnType, nil
}
return false
}
return nil, fmt.Errorf("column %q not found in table %q", pColumnName.String(), pTableName.String())
}

func (s DBSchema) ExistsTableColumn(pTableName TableName, pColumnName ColumnName) bool {
_, lError := s.CheckTableColumnType(pTableName, pColumnName)
return lError == nil
return slices.ContainsFunc(lColumnTypes, func(ct *sql.ColumnType) bool {
return strings.EqualFold(ct.Name(), pColumnName.String())
})
}

func (s DBSchema) fetchColumnTypes(pTableName TableName) ([]*sql.ColumnType, error) {
lTableName := pTableName.String()
if !regexValidTableName.MatchString(lTableName) {
return nil, fmt.Errorf("invalid table name: %q", lTableName)
}

lRows, lError := s.db.Query(fmt.Sprintf("SELECT * FROM %s WHERE 1=0", lTableName)) // NOSONAR go:S2077 - Dynamic SQL is controlled and sanitized internally
lRows, lError := s.db.Query(fmt.Sprintf("SELECT * FROM %s WHERE 1=0", pTableName.String())) // NOSONAR go:S2077 - Dynamic SQL is controlled and sanitized internally
if lError != nil {
return nil, lError
}
Expand Down
Loading