From f2bf62794eb40a47a119ec4d1104ee3cf3958eab Mon Sep 17 00:00:00 2001 From: Glen Mailer Date: Tue, 10 Mar 2026 19:50:02 +0000 Subject: [PATCH] feat: add no-transaction migration support Migrations containing `-- pgmigrate: no-transaction` on the first line will be executed outside of a transaction, enabling operations like `CREATE INDEX CONCURRENTLY` that cannot run inside a transaction block. Co-Authored-By: Claude Opus 4.6 --- migration.go | 17 ++++++++- migration_test.go | 39 ++++++++++++++++++++ migrator.go | 93 +++++++++++++++++++++++++---------------------- migrator_test.go | 60 ++++++++++++++++++++++++++++++ pgmigrate.go | 1 + 5 files changed, 165 insertions(+), 45 deletions(-) diff --git a/migration.go b/migration.go index 808050c..94d7d6b 100644 --- a/migration.go +++ b/migration.go @@ -11,8 +11,9 @@ import ( // Migration represents a single SQL migration. type Migration struct { - ID string // the filename of the migration, without the .sql extension - SQL string // the contents of the migration file + ID string // the filename of the migration, without the .sql extension + SQL string // the contents of the migration file + NoTransaction bool // if true, the migration will not be wrapped in a transaction } // MD5 computes the MD5 hash of the SQL for this migration so that it can be @@ -31,6 +32,18 @@ type AppliedMigration struct { AppliedAt time.Time // When the migration was run } +// noTransactionPrefix is the magic comment that, when present on the first +// line of a migration file, indicates the migration should not be wrapped in a +// transaction. This is useful for operations like CREATE INDEX CONCURRENTLY. +const noTransactionPrefix = "-- pgmigrate: no-transaction" + +// parseNoTransaction checks if the SQL content starts with the magic comment +// "-- pgmigrate: no-transaction" on the first line. +func parseNoTransaction(sql string) bool { + firstLine, _, _ := strings.Cut(sql, "\n") + return strings.TrimSpace(firstLine) == noTransactionPrefix +} + // IDFromFilename removes directory paths and extensions from the filename to // return just the filename (no extension). // diff --git a/migration_test.go b/migration_test.go index 15d2402..d23a20b 100644 --- a/migration_test.go +++ b/migration_test.go @@ -6,6 +6,45 @@ import ( "github.com/peterldowns/testy/check" ) +func TestParseNoTransaction(t *testing.T) { + t.Parallel() + t.Run("present on first line", func(t *testing.T) { + t.Parallel() + sql := "-- pgmigrate: no-transaction\nCREATE INDEX CONCURRENTLY idx ON users (name);" + check.Equal(t, true, parseNoTransaction(sql)) + }) + t.Run("absent", func(t *testing.T) { + t.Parallel() + sql := "CREATE TABLE users (id int);" + check.Equal(t, false, parseNoTransaction(sql)) + }) + t.Run("on second line", func(t *testing.T) { + t.Parallel() + sql := "-- some other comment\n-- pgmigrate: no-transaction\nCREATE INDEX CONCURRENTLY idx ON users (name);" + check.Equal(t, false, parseNoTransaction(sql)) + }) + t.Run("with trailing whitespace", func(t *testing.T) { + t.Parallel() + sql := "-- pgmigrate: no-transaction \nCREATE INDEX CONCURRENTLY idx ON users (name);" + check.Equal(t, true, parseNoTransaction(sql)) + }) + t.Run("with leading whitespace", func(t *testing.T) { + t.Parallel() + sql := " -- pgmigrate: no-transaction\nCREATE INDEX CONCURRENTLY idx ON users (name);" + check.Equal(t, true, parseNoTransaction(sql)) + }) + t.Run("only the comment no newline", func(t *testing.T) { + t.Parallel() + sql := "-- pgmigrate: no-transaction" + check.Equal(t, true, parseNoTransaction(sql)) + }) + t.Run("similar but different comment", func(t *testing.T) { + t.Parallel() + sql := "-- pgmigrate: no-transactions\nCREATE INDEX CONCURRENTLY idx ON users (name);" + check.Equal(t, false, parseNoTransaction(sql)) + }) +} + func TestIDFromFilename(t *testing.T) { t.Parallel() check.Equal(t, "0001_initial", IDFromFilename("0001_initial.sql")) diff --git a/migrator.go b/migrator.go index 22f8dc9..2164003 100644 --- a/migrator.go +++ b/migrator.go @@ -291,11 +291,9 @@ func (m *Migrator) inTx(ctx context.Context, db Executor, cb func(tx *sql.Tx) er return cb(tx) } -// applyMigration runs a single migration inside a transaction: -// - BEGIN; -// - apply the migration -// - insert a record marking the migration as applied -// - COMMIT; +// applyMigration runs a single migration. If the migration has NoTransaction +// set, it runs the SQL and inserts the record directly on the connection +// without a transaction. Otherwise, it wraps both operations in a transaction. func (m *Migrator) applyMigration(ctx context.Context, db Executor, migration Migration) error { startedAt := time.Now().UTC() fields := []LogField{ @@ -303,49 +301,58 @@ func (m *Migrator) applyMigration(ctx context.Context, db Executor, migration Mi {Key: "migration_checksum", Value: migration.MD5()}, {Key: "started_at", Value: startedAt}, } + if migration.NoTransaction { + fields = append(fields, LogField{Key: "no_transaction", Value: true}) + m.warn(ctx, "applying migration without transaction — if the SQL succeeds but recording fails, the migration will be applied but not tracked", fields...) + return m.applyMigrationSQL(ctx, db, migration, startedAt, fields) + } m.info(ctx, "applying migration", fields...) return m.inTx(ctx, db, func(tx *sql.Tx) error { - // Run the migration SQL - _, err := tx.ExecContext(ctx, migration.SQL) - finishedAt := time.Now().UTC() - executionTimeMs := finishedAt.Sub(startedAt).Milliseconds() - fields = append(fields, - LogField{Key: "execution_time_ms", Value: executionTimeMs}, - LogField{Key: "finished_at", Value: finishedAt}, - ) - if err != nil { - msg := "failed to apply migration" - for key, val := range pgtools.ErrorData(err) { - fields = append(fields, LogField{Key: key, Value: val}) - } - m.error(ctx, err, msg, fields...) - return fmt.Errorf("%s: %w", msg, err) - } - m.info(ctx, "migration succeeded", fields...) - // Mark the migration as applied - applied := AppliedMigration{Migration: migration} - applied.Checksum = migration.MD5() - applied.ExecutionTimeInMillis = executionTimeMs - applied.AppliedAt = startedAt - query := fmt.Sprintf(` - INSERT INTO %s - ( id, checksum, execution_time_in_millis, applied_at ) - VALUES - ( $1, $2, $3, $4 )`, - pgtools.Identifier(m.TableName), - ) - m.debug(ctx, query) - _, err = tx.ExecContext(ctx, query, applied.ID, applied.Checksum, applied.ExecutionTimeInMillis, applied.AppliedAt) - if err != nil { - msg := "failed to mark migration as applied" - m.error(ctx, err, msg, fields...) - return fmt.Errorf("%s: %w", msg, err) - } - m.info(ctx, "marked as applied", fields...) - return nil + return m.applyMigrationSQL(ctx, tx, migration, startedAt, fields) }) } +func (m *Migrator) applyMigrationSQL(ctx context.Context, exec interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) +}, migration Migration, startedAt time.Time, fields []LogField) error { + _, err := exec.ExecContext(ctx, migration.SQL) + finishedAt := time.Now().UTC() + executionTimeMs := finishedAt.Sub(startedAt).Milliseconds() + fields = append(fields, + LogField{Key: "execution_time_ms", Value: executionTimeMs}, + LogField{Key: "finished_at", Value: finishedAt}, + ) + if err != nil { + msg := "failed to apply migration" + for key, val := range pgtools.ErrorData(err) { + fields = append(fields, LogField{Key: key, Value: val}) + } + m.error(ctx, err, msg, fields...) + return fmt.Errorf("%s: %w", msg, err) + } + m.info(ctx, "migration succeeded", fields...) + applied := AppliedMigration{Migration: migration} + applied.Checksum = migration.MD5() + applied.ExecutionTimeInMillis = executionTimeMs + applied.AppliedAt = startedAt + query := fmt.Sprintf(` + INSERT INTO %s + ( id, checksum, execution_time_in_millis, applied_at ) + VALUES + ( $1, $2, $3, $4 )`, + pgtools.Identifier(m.TableName), + ) + m.debug(ctx, query) + _, err = exec.ExecContext(ctx, query, applied.ID, applied.Checksum, applied.ExecutionTimeInMillis, applied.AppliedAt) + if err != nil { + msg := "failed to mark migration as applied" + m.error(ctx, err, msg, fields...) + return fmt.Errorf("%s: %w", msg, err) + } + m.info(ctx, "marked as applied", fields...) + return nil +} + // Verify returns a list of [VerificationError]s with warnings for any migrations that: // // - Are marked as applied in the database table but do not exist in the diff --git a/migrator_test.go b/migrator_test.go index 582b674..1b5ba4d 100644 --- a/migrator_test.go +++ b/migrator_test.go @@ -340,6 +340,66 @@ func TestAppliedAndPlanWithoutMigrationsTable(t *testing.T) { assert.Nil(t, err) } +func TestApplyNoTransactionMigration(t *testing.T) { + t.Parallel() + ctx := context.Background() + logger := pgmigrate.NewTestLogger(t) + err := withdb.WithDB(ctx, "pgx", func(db *sql.DB) error { + migrations := []pgmigrate.Migration{ + { + ID: "0001_initial", + SQL: "CREATE TABLE users (id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, name TEXT);", + NoTransaction: true, + }, + } + migrator := pgmigrate.NewMigrator(migrations) + migrator.Logger = logger + verrs, err := migrator.Migrate(ctx, db) + assert.Nil(t, err) + assert.Equal(t, nil, verrs) + + applied, err := migrator.Applied(ctx, db) + assert.Nil(t, err) + assert.Equal(t, 1, len(applied)) + check.Equal(t, migrations[0].ID, applied[0].ID) + check.Equal(t, migrations[0].MD5(), applied[0].Checksum) + return nil + }) + assert.Nil(t, err) +} + +func TestApplyCreateIndexConcurrently(t *testing.T) { + t.Parallel() + ctx := context.Background() + logger := pgmigrate.NewTestLogger(t) + err := withdb.WithDB(ctx, "pgx", func(db *sql.DB) error { + migrations := []pgmigrate.Migration{ + { + ID: "0001_create_table", + SQL: "CREATE TABLE users (id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, name TEXT);", + }, + { + ID: "0002_create_index", + SQL: "-- pgmigrate: no-transaction\nCREATE INDEX CONCURRENTLY idx_users_name ON users (name);", + NoTransaction: true, + }, + } + migrator := pgmigrate.NewMigrator(migrations) + migrator.Logger = logger + verrs, err := migrator.Migrate(ctx, db) + assert.Nil(t, err) + assert.Equal(t, nil, verrs) + + applied, err := migrator.Applied(ctx, db) + assert.Nil(t, err) + assert.Equal(t, 2, len(applied)) + check.Equal(t, "0001_create_table", applied[0].ID) + check.Equal(t, "0002_create_index", applied[1].ID) + return nil + }) + assert.Nil(t, err) +} + // By default, pgmigrate will use the [DefaultTableName] table to // keep track of migrations. Because this is a fully qualified table name, // including a schema prefix, pgmigrate will not be affected by migrations diff --git a/pgmigrate.go b/pgmigrate.go index 1dd6fb7..82c5ce6 100644 --- a/pgmigrate.go +++ b/pgmigrate.go @@ -41,6 +41,7 @@ func Load(filesystem fs.FS) ([]Migration, error) { return err } migration.SQL = string(data) + migration.NoTransaction = parseNoTransaction(migration.SQL) migrations = append(migrations, migration) return nil }); err != nil {