diff --git a/v2/dbutils/dbutils.go b/v2/dbutils/dbutils.go index 8f66b3d..da81b84 100644 --- a/v2/dbutils/dbutils.go +++ b/v2/dbutils/dbutils.go @@ -154,34 +154,24 @@ type DB struct { // Open opens a database connection given a config struct // It expects a fs.FS in order to fetch and run the DB migrations // If you don't need them, just pass nil instead -func Open(conf DBConfig, fsys fs.FS) (*DB, int64, error) { +func Open(conf DBConfig, fsys fs.FS) (db *DB, err error) { if conf.Driver == "" { - return nil, 0, errors.New("no SQL driver specified: please use one of [mssql,postgres]") + return nil, errors.New("no SQL driver specified: please use one of [mssql,postgres]") } connectionString := fromDBConfToConnectionString(conf) if connectionString == "" { - return nil, 0, errors.New("unsupported driver: " + conf.Driver) + return nil, errors.New("unsupported driver: " + conf.Driver) } - // Init Goose - err := goose.SetDialect(conf.Driver) - if err != nil { - return nil, -1, fmt.Errorf("set migrations dialect: %w", err) - } - - goose.SetBaseFS(fsys) - - currentVersion := int64(-1) - - var db *sql.DB + var sqlDB *sql.DB // Make sure the DB is actually reachable for _, delay := range connectionRetries { - db, err = sql.Open(conf.Driver, connectionString) + sqlDB, err = sql.Open(conf.Driver, connectionString) if err == nil { - err = db.Ping() + err = sqlDB.Ping() if err == nil { break } @@ -189,34 +179,48 @@ func Open(conf DBConfig, fsys fs.FS) (*DB, int64, error) { time.Sleep(delay * time.Second) } if err != nil { - return nil, -1, fmt.Errorf("reaching DB server: %w", err) + return nil, fmt.Errorf("reaching DB server: %w", err) } + db = &DB{DB: sqlDB, conf: conf, fsys: fsys} // DB should be ready, run migrations if needed if conf.Migrations.Run { - // Goose wants to use the "sqlserver" driver, never "mssql" - driver := conf.Driver - if driver == "mssql" { - driver = "sqlserver" - } - db, err := sql.Open(driver, connectionString) - if err != nil { - return nil, -1, fmt.Errorf("open db for migrations: %w", err) + if fsys != nil { + err = db.Migrate(fsys) + if err != nil { + return nil, fmt.Errorf("cannote run db migrations: %w", err) + } } - defer db.Close() + } - currentVersion, err = goose.GetDBVersion(db) - if err != nil { - return nil, -1, fmt.Errorf("get db version: %w", err) - } + return db, nil +} - err = goose.Up(db, conf.Migrations.Path) - if err != nil { - return nil, -1, fmt.Errorf("migrate db: %w", err) - } +func (d *DB) Migrate(fsys fs.FS) (err error) { + if !d.conf.Migrations.Run { + return nil + } + // Init Goose + err = goose.SetDialect(d.conf.Driver) + if err != nil { + return fmt.Errorf("set migrations dialect: %w", err) + } + d.fsys = fsys + goose.SetBaseFS(d.fsys) + goose.SetTableName(goose.DefaultTablename) + + // Goose wants to use the "sqlserver" driver, never "mssql" + driver := d.conf.Driver + if driver == "mssql" { + driver = "sqlserver" } - return &DB{DB: db, conf: conf, fsys: fsys}, currentVersion, nil + err = goose.Up(d.DB, d.conf.Migrations.Path) + if err != nil { + return fmt.Errorf("migrate db: %w", err) + } + + return nil } // Convert the database configuration to connection string @@ -276,7 +280,7 @@ func (d *DB) Up() error { // Version return the current DB version func (d *DB) Version() (int64, error) { if d.fsys == nil { - return -1, errors.New("can't get current version: no file system was passed to Open()") + return -1, nil } return goose.GetDBVersion(d.DB) }