diff --git a/database/notsupported.go b/database/notsupported.go index a2dbf16..95748a2 100644 --- a/database/notsupported.go +++ b/database/notsupported.go @@ -1,3 +1,4 @@ +//nolint:revive // ignore unused parameters because these are stubs package database // NotSupportedDB is a db implementation used on database drivers when the diff --git a/mysql/mysql.go b/mysql/mysql.go index 0273491..4e89fb4 100644 --- a/mysql/mysql.go +++ b/mysql/mysql.go @@ -44,10 +44,22 @@ func (db *DB) Open(dataSourceName string, opt ...database.Option) error { if err != nil { return errors.Wrap(err, "error connecting to mysql") } - _, err = _db.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", opts.Database)) + rows, err := _db.Query(fmt.Sprintf("SHOW DATABASES LIKE '%s'", opts.Database)) if err != nil { - return errors.Wrapf(err, "error creating database %s (if not exists)", opts.Database) + return errors.Wrapf(err, "error checking if database %s exists", opts.Database) } + defer rows.Close() + // db doesn't exist, create it + if !rows.Next() { + _, err = _db.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", opts.Database)) + if err != nil { + return errors.Wrapf(err, "error creating database %s (if not exists)", opts.Database) + } + } + if err = rows.Err(); err != nil { + return fmt.Errorf("error accessing databases: %w", err) + } + parsedDSN.DBName = opts.Database db.db, err = sql.Open("mysql", parsedDSN.FormatDSN()) if err != nil { @@ -82,6 +94,10 @@ func createTableQry(bucket []byte) string { return fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s`(nkey VARBINARY(255), nvalue BLOB, PRIMARY KEY (nkey));", bucket) } +func checkTableQry(bucket []byte) string { + return fmt.Sprintf("SHOW TABLES LIKE '%s';", bucket) +} + func deleteTableQry(bucket []byte) string { return fmt.Sprintf("DROP TABLE `%s`", bucket) } @@ -261,9 +277,20 @@ func (db *DB) Update(tx *database.Tx) error { // CreateTable creates a table in the database. func (db *DB) CreateTable(bucket []byte) error { - _, err := db.db.Exec(createTableQry(bucket)) + rows, err := db.db.Query(checkTableQry(bucket)) if err != nil { - return errors.Wrapf(err, "failed to create table %s", bucket) + return errors.Wrapf(err, "failed to check table %s", bucket) + } + defer rows.Close() + // Table doesn't exist, create it. + if !rows.Next() { + _, err := db.db.Exec(createTableQry(bucket)) + if err != nil { + return errors.Wrapf(err, "failed to create table %s", bucket) + } + } + if err = rows.Err(); err != nil { + return errors.Wrap(err, "error accessing row") } return nil }