diff --git a/database.go b/database.go index 180de03..b16cd50 100644 --- a/database.go +++ b/database.go @@ -2,9 +2,8 @@ package ormshift import ( "database/sql" - "errors" - "fmt" + "github.com/ordershift/ormshift/errs" "github.com/ordershift/ormshift/schema" ) @@ -35,16 +34,17 @@ type Database struct { func OpenDatabase(driver DatabaseDriver, params ConnectionParams) (*Database, error) { if driver == nil { - return nil, errors.New("DatabaseDriver cannot be nil") + err := errs.Nil("database driver") + return nil, failedToOpenDatabase(err) } connectionString := driver.ConnectionString(params) db, err := sql.Open(driver.Name(), connectionString) if err != nil { - return nil, fmt.Errorf("sql.Open failed: %w", err) + return nil, failedToOpenDatabase(err) } dbSchema, err := driver.DBSchema(db) if err != nil { - return nil, fmt.Errorf("failed to get DB schema: %w", err) + return nil, failedToOpenDatabase(err) } return &Database{ @@ -56,6 +56,10 @@ func OpenDatabase(driver DatabaseDriver, params ConnectionParams) (*Database, er }, nil } +func failedToOpenDatabase(err error) error { + return errs.FailedTo("open database", err) +} + func (d *Database) Close() error { return d.db.Close() } diff --git a/database_test.go b/database_test.go index 557a552..a1518c5 100644 --- a/database_test.go +++ b/database_test.go @@ -33,7 +33,7 @@ func TestOpenDatabaseWithNilDriver(t *testing.T) { if !testutils.AssertNilResultAndNotNilError(t, db, err, "ormshift.OpenDatabase") { return } - testutils.AssertErrorMessage(t, "DatabaseDriver cannot be nil", err, "ormshift.OpenDatabase") + testutils.AssertErrorMessage(t, "failed to open database: database driver cannot be nil", err, "ormshift.OpenDatabase") } func TestOpenDatabaseWithBadDriver(t *testing.T) { @@ -42,7 +42,7 @@ func TestOpenDatabaseWithBadDriver(t *testing.T) { if !testutils.AssertNilResultAndNotNilError(t, db, err, "ormshift.OpenDatabase") { return } - testutils.AssertErrorMessage(t, "sql.Open failed: sql: unknown driver \"bad-driver-name\" (forgotten import?)", err, "ormshift.OpenDatabase") + testutils.AssertErrorMessage(t, "failed to open database: sql: unknown driver \"bad-driver-name\" (forgotten import?)", err, "ormshift.OpenDatabase") } func TestOpenDatabaseWithBadSchema(t *testing.T) { @@ -51,7 +51,7 @@ func TestOpenDatabaseWithBadSchema(t *testing.T) { if !testutils.AssertNilResultAndNotNilError(t, db, err, "ormshift.OpenDatabase") { return } - testutils.AssertErrorMessage(t, "failed to get DB schema: intentionally bad schema", err, "ormshift.OpenDatabase") + testutils.AssertErrorMessage(t, "failed to open database: intentionally bad schema", err, "ormshift.OpenDatabase") } func TestClose(t *testing.T) { diff --git a/dialects/postgresql/driver_test.go b/dialects/postgresql/driver_test.go index 42e7aef..c5537c5 100644 --- a/dialects/postgresql/driver_test.go +++ b/dialects/postgresql/driver_test.go @@ -45,5 +45,5 @@ func TestDBSchemaFailsWhenDBIsNil(t *testing.T) { if !testutils.AssertNilResultAndNotNilError(t, schema, err, "driver.DBSchema") { return } - testutils.AssertErrorMessage(t, "sql.DB cannot be nil", err, "driver.DBSchema") + testutils.AssertErrorMessage(t, "failed to get db schema: db cannot be nil", err, "driver.DBSchema") } diff --git a/dialects/sqlite/driver_test.go b/dialects/sqlite/driver_test.go index eca7c84..a09915a 100644 --- a/dialects/sqlite/driver_test.go +++ b/dialects/sqlite/driver_test.go @@ -52,5 +52,5 @@ func TestDBSchemaFailsWhenDBIsNil(t *testing.T) { if !testutils.AssertNilResultAndNotNilError(t, schema, err, "driver.DBSchema") { return } - testutils.AssertErrorMessage(t, "sql.DB cannot be nil", err, "driver.DBSchema") + testutils.AssertErrorMessage(t, "failed to get db schema: db cannot be nil", err, "driver.DBSchema") } diff --git a/dialects/sqlserver/driver_test.go b/dialects/sqlserver/driver_test.go index f22286c..84df702 100644 --- a/dialects/sqlserver/driver_test.go +++ b/dialects/sqlserver/driver_test.go @@ -48,5 +48,5 @@ func TestDBSchemaFailsWhenDBIsNil(t *testing.T) { if !testutils.AssertNilResultAndNotNilError(t, schema, err, "driver.DBSchema") { return } - testutils.AssertErrorMessage(t, "sql.DB cannot be nil", err, "driver.DBSchema") + testutils.AssertErrorMessage(t, "failed to get db schema: db cannot be nil", err, "driver.DBSchema") } diff --git a/errs/errors.go b/errs/errors.go new file mode 100644 index 0000000..b417d73 --- /dev/null +++ b/errs/errors.go @@ -0,0 +1,42 @@ +package errs + +import ( + "errors" + "fmt" +) + +var ( + ErrInvalid = errors.New("invalid") + ErrNil = errors.New("cannot be nil") + ErrFailedTo = errors.New("failed to") + ErrAlreadyExists = errors.New("already exists") +) + +// Invalid returns an error indicating that value is not valid for the given label. +// The error wraps ErrInvalid, allowing it to be checked with errors.Is. +func Invalid(label string) error { + return fmt.Errorf("%w %s", ErrInvalid, label) +} + +// Nil returns an error indicating that the value identified by label is null. +// The error wraps ErrNil, allowing it to be checked with errors.Is. +func Nil(label string) error { + return fmt.Errorf("%s %w", label, ErrNil) +} + +// FailedTo returns an error indicating a failure to perform the given action. +// It wraps ErrFailedTo and optionally wraps the provided cause. +func FailedTo(action string, err error) error { + failedToErr := fmt.Errorf("%w %s", ErrFailedTo, action) + if err == nil { + return failedToErr + } + return fmt.Errorf("%w: %w", failedToErr, err) +} + +// AlreadyExists returns an error indicating that the resource identified by label +// already exists. +// The error wraps ErrAlreadyExists, allowing it to be checked with errors.Is. +func AlreadyExists(label string) error { + return fmt.Errorf("%s %w", label, ErrAlreadyExists) +} diff --git a/errs/errors_test.go b/errs/errors_test.go new file mode 100644 index 0000000..9e6583a --- /dev/null +++ b/errs/errors_test.go @@ -0,0 +1,48 @@ +package errs_test + +import ( + "testing" + + "github.com/ordershift/ormshift/errs" + "github.com/ordershift/ormshift/internal/testutils" +) + +type errorTester struct { + expectedMessageError string + expectedTypeError error + testedError error +} + +func TestErrors(t *testing.T) { + testers := []errorTester{ + { + expectedMessageError: "invalid driver", + expectedTypeError: errs.ErrInvalid, + testedError: errs.Invalid("driver"), + }, + { + expectedMessageError: "database driver cannot be nil", + expectedTypeError: errs.ErrNil, + testedError: errs.Nil("database driver"), + }, + { + expectedMessageError: "column already exists", + expectedTypeError: errs.ErrAlreadyExists, + testedError: errs.AlreadyExists("column"), + }, + { + expectedMessageError: "failed to get db schema", + expectedTypeError: errs.ErrFailedTo, + testedError: errs.FailedTo("get db schema", nil), + }, + { + expectedMessageError: "failed to get db schema: db cannot be nil", + expectedTypeError: errs.ErrFailedTo, + testedError: errs.FailedTo("get db schema", errs.Nil("db")), + }, + } + for _, tester := range testers { + testutils.AssertErrorType(t, tester.expectedTypeError, tester.testedError) + testutils.AssertErrorMessage(t, tester.expectedMessageError, tester.testedError, "errs pkg") + } +} diff --git a/internal/testutils/assert.go b/internal/testutils/assert.go index f7f5566..fb44544 100644 --- a/internal/testutils/assert.go +++ b/internal/testutils/assert.go @@ -2,6 +2,7 @@ package testutils import ( "database/sql" + "errors" "strings" "testing" ) @@ -21,6 +22,18 @@ func AssertErrorMessage(t *testing.T, expectedErrorMessage string, err error, fu } } +func AssertErrorType(t *testing.T, expectedErrorType, err error) bool { + if !errors.Is(err, expectedErrorType) { + t.Errorf( + "error with message [%s] has not expected type [%s]", + err.Error(), + expectedErrorType.Error(), + ) + return false + } + return true +} + func AssertNotNilResultAndNilError[R any](t *testing.T, result *R, err error, functionName string) bool { res := true if result == nil { diff --git a/migrations/migrations.go b/migrations/migrations.go index a5f2bfe..658a6ef 100644 --- a/migrations/migrations.go +++ b/migrations/migrations.go @@ -1,6 +1,8 @@ package migrations -import "github.com/ordershift/ormshift" +import ( + "github.com/ordershift/ormshift" +) type Migration interface { Up(migrator *Migrator) error diff --git a/migrations/migrations_test.go b/migrations/migrations_test.go index d85b54d..968c08a 100644 --- a/migrations/migrations_test.go +++ b/migrations/migrations_test.go @@ -80,7 +80,7 @@ func TestMigrateFailsWhenDatabaseIsClosed(t *testing.T) { if !testutils.AssertNilResultAndNotNilError(t, migrator, err, "migrations.Migrate") { return } - testutils.AssertErrorMessage(t, "failed to get applied migration names: sql: database is closed", err, "migrations.Migrate") + testutils.AssertErrorMessage(t, "failed to migrate: failed to get applied migration names: sql: database is closed", err, "migrations.Migrate") } func TestMigrateFailsWhenMigrationUpFails(t *testing.T) { diff --git a/migrations/migrator.go b/migrations/migrator.go index 85f7c37..26b650d 100644 --- a/migrations/migrator.go +++ b/migrations/migrator.go @@ -6,6 +6,7 @@ import ( "time" "github.com/ordershift/ormshift" + "github.com/ordershift/ormshift/errs" "github.com/ordershift/ormshift/schema" ) @@ -18,15 +19,18 @@ type Migrator struct { func NewMigrator(database *ormshift.Database, config *MigratorConfig) (*Migrator, error) { if database == nil { - return nil, fmt.Errorf("database cannot be nil") + err := errs.Nil("database") + return nil, failedToMigrate(err) } if config == nil { - return nil, fmt.Errorf("migrator config cannot be nil") + err := errs.Nil("migrator config") + return nil, failedToMigrate(err) } appliedMigrationNames, err := getAppliedMigrationNames(database, config) if err != nil { - return nil, fmt.Errorf("failed to get applied migration names: %w", err) + err := errs.FailedTo("get applied migration names", err) + return nil, failedToMigrate(err) } appliedMigrations := make(map[string]bool, len(appliedMigrationNames)) for _, name := range appliedMigrationNames { @@ -41,6 +45,10 @@ func NewMigrator(database *ormshift.Database, config *MigratorConfig) (*Migrator }, nil } +func failedToMigrate(err error) error { + return errs.FailedTo("migrate", err) +} + func (m *Migrator) Add(migration Migration) { m.migrations = append(m.migrations, migration) } diff --git a/migrations/migrator_test.go b/migrations/migrator_test.go index f72d2d9..c03d9cb 100644 --- a/migrations/migrator_test.go +++ b/migrations/migrator_test.go @@ -13,7 +13,7 @@ import ( func TestNewMigratorWhenDatabaseIsNil(t *testing.T) { migrator, err := migrations.NewMigrator(nil, migrations.NewMigratorConfig()) testutils.AssertNilResultAndNotNilError(t, migrator, err, "migrations.NewMigrator[database=nil]") - testutils.AssertErrorMessage(t, "database cannot be nil", err, "migrations.NewMigrator[database=nil]") + testutils.AssertErrorMessage(t, "failed to migrate: database cannot be nil", err, "migrations.NewMigrator[database=nil]") } func TestNewMigratorWhenConfigIsNil(t *testing.T) { @@ -25,7 +25,7 @@ func TestNewMigratorWhenConfigIsNil(t *testing.T) { migrator, err := migrations.NewMigrator(db, nil) testutils.AssertNilResultAndNotNilError(t, migrator, err, "migrations.NewMigrator[config=nil]") - testutils.AssertErrorMessage(t, "migrator config cannot be nil", err, "migrations.NewMigrator[config=nil]") + testutils.AssertErrorMessage(t, "failed to migrate: migrator config cannot be nil", err, "migrations.NewMigrator[config=nil]") } func TestNewMigratorWhenDatabaseIsInvalid(t *testing.T) { @@ -38,7 +38,7 @@ func TestNewMigratorWhenDatabaseIsInvalid(t *testing.T) { migrator, err := migrations.NewMigrator(db, migrations.NewMigratorConfig()) testutils.AssertNilResultAndNotNilError(t, migrator, err, "migrations.NewMigrator[database=invalid]") - testutils.AssertErrorMessage(t, "failed to get applied migration names: missing \"=\" after \"invalid-connection-string\" in connection info string\"", err, "migrations.NewMigrator[database=invalid]") + testutils.AssertErrorMessage(t, "failed to migrate: failed to get applied migration names: missing \"=\" after \"invalid-connection-string\" in connection info string\"", err, "migrations.NewMigrator[database=invalid]") } func TestApplyAllMigrationsFailsWhenRecordingFails(t *testing.T) { diff --git a/schema/schema.go b/schema/schema.go index d9c9790..7aa0f57 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -2,9 +2,10 @@ package schema import ( "database/sql" - "errors" "slices" "strings" + + "github.com/ordershift/ormshift/errs" ) type DBSchema struct { @@ -21,7 +22,8 @@ func NewDBSchema( columnTypesQueryFunc ColumnTypesQueryFunc, ) (*DBSchema, error) { if db == nil { - return nil, errors.New("sql.DB cannot be nil") + err := errs.Nil("db") + return nil, failedToGetDBSchema(err) } return &DBSchema{ db: db, @@ -30,6 +32,10 @@ func NewDBSchema( }, nil } +func failedToGetDBSchema(err error) error { + return errs.FailedTo("get db schema", err) +} + func (s *DBSchema) HasTable(table string) bool { tables, err := s.fetchTableNames() if err != nil { diff --git a/schema/schema_test.go b/schema/schema_test.go index 838ce6a..d181ca8 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -32,7 +32,7 @@ func TestNewDBSchemaFailsWhenDBIsNil(t *testing.T) { if !testutils.AssertNilResultAndNotNilError(t, dbSchema, err, "schema.NewDBSchema") { return } - testutils.AssertErrorMessage(t, "sql.DB cannot be nil", err, "schema.NewDBSchema") + testutils.AssertErrorMessage(t, "failed to get db schema: db cannot be nil", err, "schema.NewDBSchema") } func TestHasColumn(t *testing.T) { diff --git a/schema/table.go b/schema/table.go index becb105..82c03f8 100644 --- a/schema/table.go +++ b/schema/table.go @@ -4,6 +4,8 @@ import ( "fmt" "slices" "strings" + + "github.com/ordershift/ormshift/errs" ) type Table struct { @@ -33,9 +35,14 @@ func (t *Table) AddColumns(params ...NewColumnParams) error { return strings.EqualFold(column.Name(), c.Name()) }) if exists { - return fmt.Errorf("column %q already exists in table %q", column.Name(), t.Name()) + return failedToAddColumnInTable(*t, column, errs.AlreadyExists("column")) } t.columns = append(t.columns, column) } return nil } + +func failedToAddColumnInTable(table Table, column Column, err error) error { + msg := fmt.Sprintf("add column %q in table %q", column.Name(), table.Name()) + return errs.FailedTo(msg, err) +} diff --git a/schema/table_test.go b/schema/table_test.go index f46d6b1..16134d4 100644 --- a/schema/table_test.go +++ b/schema/table_test.go @@ -17,5 +17,5 @@ func TestAddColumnFailsWhenAlreadyExists(t *testing.T) { if !testutils.AssertNotNilError(t, err, "Table.AddColumns") { return } - testutils.AssertErrorMessage(t, fmt.Sprintf("column %q already exists in table %q", "value", "product_attribute"), err, "Table.AddColumns") + testutils.AssertErrorMessage(t, fmt.Sprintf("failed to add column %q in table %q: column already exists", "value", "product_attribute"), err, "Table.AddColumns") }