Skip to content

Commit d474e2e

Browse files
authored
Merge pull request #2 from openframebox/sqlite-driver
feat: add sqlite driver
2 parents fe8c5ad + 3247375 commit d474e2e

4 files changed

Lines changed: 466 additions & 6 deletions

File tree

driver_sqlite.go

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
package gomigration
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"fmt"
7+
"time"
8+
9+
_ "github.com/ncruces/go-sqlite3/driver"
10+
_ "github.com/ncruces/go-sqlite3/embed"
11+
)
12+
13+
// SqliteDriver is a driver for sqlite
14+
type SqliteDriver struct {
15+
db *sql.DB
16+
migrationTableName string
17+
}
18+
19+
// NewSqliteDriver creates a new SqliteDriver
20+
func NewSqliteDriver(
21+
database string,
22+
) (*SqliteDriver, error) {
23+
// Open database
24+
db, err := sql.Open("sqlite3", database)
25+
if err != nil {
26+
return nil, err
27+
}
28+
29+
// Ping database
30+
if err := db.Ping(); err != nil {
31+
return nil, err
32+
}
33+
34+
// Return the driver with a default table name
35+
return &(SqliteDriver{db, "migrations"}), nil
36+
}
37+
38+
// Close closes the database connection
39+
func (d *SqliteDriver) Close() error {
40+
if d.db != nil {
41+
if err := d.db.Close(); err != nil {
42+
return err
43+
}
44+
}
45+
46+
return nil
47+
}
48+
49+
// SetMigrationTableName sets the migration table name of the migration tracking table
50+
func (d *SqliteDriver) SetMigrationTableName(name string) {
51+
if name == "" {
52+
name = "migrations"
53+
}
54+
d.migrationTableName = name
55+
}
56+
57+
// CreateMigrationTable creates the migration tracking table
58+
func (d *SqliteDriver) CreateMigrationsTable(ctx context.Context) error {
59+
query := fmt.Sprintf(`
60+
CREATE TABLE IF NOT EXISTS %s (
61+
name VARCHAR(255) PRIMARY KEY NOT NULL,
62+
executed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
63+
);
64+
`, d.migrationTableName)
65+
66+
_, err := d.db.ExecContext(ctx, query)
67+
return err
68+
}
69+
70+
// GetExecutedMigrations returns a list of previously executed migrations
71+
func (d *SqliteDriver) GetExecutedMigrations(ctx context.Context, reverse bool) ([]ExecutedMigration, error) {
72+
order := "ASC"
73+
if reverse {
74+
order = "DESC"
75+
}
76+
77+
query := fmt.Sprintf(`SELECT name, executed_at FROM %s ORDER BY name %s`, d.migrationTableName, order)
78+
rows, err := d.db.QueryContext(ctx, query)
79+
if err != nil {
80+
return nil, err
81+
}
82+
defer rows.Close()
83+
84+
var migrations []ExecutedMigration
85+
for rows.Next() {
86+
var name string
87+
var executedAt time.Time
88+
if err := rows.Scan(&name, &executedAt); err != nil {
89+
return nil, err
90+
}
91+
migrations = append(migrations, ExecutedMigration{Name: name, ExecutedAt: executedAt})
92+
}
93+
94+
return migrations, rows.Err()
95+
}
96+
97+
// CleanDatabase drops all table from the current database.
98+
func (d *SqliteDriver) CleanDatabase(ctx context.Context) error {
99+
// Disable FK checks temporarily
100+
_, err := d.db.ExecContext(ctx, `PRAGMA foreign_keys = OFF;`)
101+
if err != nil {
102+
return fmt.Errorf("failed to disable FK checks: %w", err)
103+
}
104+
105+
// Get all user-defined table names (excluding sqlite internal tables)
106+
rows, err := d.db.QueryContext(ctx, `
107+
SELECT name
108+
FROM sqlite_master
109+
WHERE type = 'table'
110+
AND name NOT LIKE 'sqlite_%';
111+
`)
112+
if err != nil {
113+
return fmt.Errorf("failed to query tables: %w", err)
114+
}
115+
defer rows.Close()
116+
117+
var tableNames []string
118+
for rows.Next() {
119+
var table string
120+
if err := rows.Scan(&table); err != nil {
121+
return fmt.Errorf("failed to scan table name: %w", err)
122+
}
123+
tableNames = append(tableNames, fmt.Sprintf(`"%s"`, table))
124+
}
125+
126+
// No tables to drop
127+
if len(tableNames) == 0 {
128+
// Re-enable FK checks before returning
129+
_, _ = d.db.ExecContext(ctx, `PRAGMA foreign_keys = ON;`)
130+
return nil
131+
}
132+
133+
// Drop all tables (SQLite doesn't support dropping multiple tables in one statement)
134+
for _, tableName := range tableNames {
135+
dropSQL := fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName)
136+
_, err = d.db.ExecContext(ctx, dropSQL)
137+
if err != nil {
138+
return fmt.Errorf("failed to drop table %s: %w", tableName, err)
139+
}
140+
}
141+
142+
// Re-enable FK checks
143+
_, err = d.db.ExecContext(ctx, `PRAGMA foreign_keys = ON;`)
144+
if err != nil {
145+
return fmt.Errorf("failed to re-enable FK checks: %w", err)
146+
}
147+
148+
return nil
149+
}
150+
151+
// ApplyMigrations applies a batch of "up" migrations with optional callbacks.
152+
func (d *SqliteDriver) ApplyMigrations(
153+
ctx context.Context,
154+
migrations []Migration,
155+
onRunning func(migration *Migration),
156+
onSuccess func(migration *Migration),
157+
onFailed func(migration *Migration, err error),
158+
) error {
159+
for i := range migrations {
160+
mig := migrations[i]
161+
162+
if onRunning != nil {
163+
onRunning(&mig)
164+
}
165+
166+
// Execute the migration SQL
167+
if err := d.executeMigrationSQL(ctx, mig.UpScript()); err != nil {
168+
if onFailed != nil {
169+
onFailed(&mig, err)
170+
}
171+
return fmt.Errorf("failed to apply migration %s: %w", mig.Name(), err)
172+
}
173+
174+
// Record the migration
175+
if err := d.insertExecutedMigration(ctx, mig.Name(), time.Now()); err != nil {
176+
if onFailed != nil {
177+
onFailed(&mig, err)
178+
}
179+
return fmt.Errorf("failed to record migration %s: %w", mig.Name(), err)
180+
}
181+
182+
if onSuccess != nil {
183+
onSuccess(&mig)
184+
}
185+
}
186+
return nil
187+
}
188+
189+
// UnapplyMigrations rolls back a batch of "down" migrations with optional callbacks.
190+
func (d *SqliteDriver) UnapplyMigrations(
191+
ctx context.Context,
192+
migrations []Migration,
193+
onRunning func(migration *Migration),
194+
onSuccess func(migration *Migration),
195+
onFailed func(migration *Migration, err error),
196+
) error {
197+
for i := range migrations {
198+
mig := migrations[i]
199+
200+
if onRunning != nil {
201+
onRunning(&mig)
202+
}
203+
204+
// Execute the down migration SQL
205+
if err := d.executeMigrationSQL(ctx, mig.DownScript()); err != nil {
206+
if onFailed != nil {
207+
onFailed(&mig, err)
208+
}
209+
return fmt.Errorf("failed to unapply migration %s: %w", mig.Name(), err)
210+
}
211+
212+
// Remove migration record from tracking table
213+
if err := d.removeExecutedMigration(ctx, mig.Name()); err != nil {
214+
if onFailed != nil {
215+
onFailed(&mig, err)
216+
}
217+
return fmt.Errorf("failed to remove migration record %s: %w", mig.Name(), err)
218+
}
219+
220+
if onSuccess != nil {
221+
onSuccess(&mig)
222+
}
223+
}
224+
return nil
225+
}
226+
227+
// executeMigrationSQL runs a raw SQL migration script.
228+
func (d *SqliteDriver) executeMigrationSQL(ctx context.Context, sql string) error {
229+
if sql == "" {
230+
return nil
231+
}
232+
_, err := d.db.ExecContext(ctx, sql)
233+
return err
234+
}
235+
236+
// insertExecutedMigration logs a migration into the migration tracking table.
237+
func (d *SqliteDriver) insertExecutedMigration(ctx context.Context, name string, executedAt time.Time) error {
238+
query := fmt.Sprintf(`INSERT INTO %s (name, executed_at) VALUES (?, ?)`, d.migrationTableName)
239+
_, err := d.db.ExecContext(ctx, query, name, executedAt)
240+
return err
241+
}
242+
243+
// removeExecutedMigration deletes a migration record from the migration table.
244+
func (d *SqliteDriver) removeExecutedMigration(ctx context.Context, name string) error {
245+
query := fmt.Sprintf(`DELETE FROM %s WHERE name = ?`, d.migrationTableName)
246+
_, err := d.db.ExecContext(ctx, query, name)
247+
return err
248+
}

0 commit comments

Comments
 (0)