Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions admin/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"database/sql"
"fmt"
"net/url"

_ "modernc.org/sqlite" // register sqlite driver
)
Expand All @@ -24,24 +25,17 @@ func (s *Store) DB() *sql.DB {
// Open creates a new Store with the given SQLite database path.
// It configures WAL mode, busy timeout, and foreign keys, then runs
// any pending schema migrations.
//
// Pragmas are passed via the DSN so that every connection in the
// database/sql pool receives them, not just the first one.
func Open(ctx context.Context, dbPath string) (*Store, error) {
db, err := sql.Open("sqlite", dbPath)
dsn := buildDSN(dbPath)

db, err := sql.Open("sqlite", dsn)
if err != nil {
return nil, fmt.Errorf("opening database: %w", err)
}

pragmas := []string{
"PRAGMA journal_mode=WAL",
"PRAGMA busy_timeout=5000",
"PRAGMA foreign_keys=ON",
}
for _, p := range pragmas {
if _, err := db.ExecContext(ctx, p); err != nil {
_ = db.Close() // best-effort cleanup; primary error is the pragma failure
return nil, fmt.Errorf("setting pragma %q: %w", p, err)
}
}

if err := db.PingContext(ctx); err != nil {
_ = db.Close() // best-effort cleanup; primary error is the ping failure
return nil, fmt.Errorf("pinging database: %w", err)
Expand All @@ -56,6 +50,17 @@ func Open(ctx context.Context, dbPath string) (*Store, error) {
return s, nil
}

// buildDSN constructs a SQLite DSN with per-connection pragmas.
// Using _pragma query parameters ensures every pooled connection
// gets WAL mode, a busy timeout, and foreign key enforcement.
func buildDSN(dbPath string) string {
v := url.Values{}
v.Add("_pragma", "journal_mode(WAL)")
v.Add("_pragma", "busy_timeout(5000)")
v.Add("_pragma", "foreign_keys(1)")
return dbPath + "?" + v.Encode()
}

// Close closes the database connection.
func (s *Store) Close() error {
if err := s.db.Close(); err != nil {
Expand Down
60 changes: 60 additions & 0 deletions admin/store/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"path/filepath"
"sort"
"strings"
"testing"
)

Expand Down Expand Up @@ -136,3 +137,62 @@ func TestOpen_InvalidPath_ReturnsError(t *testing.T) {
t.Error("expected error, got nil")
}
}

func TestOpen_PragmasApplyToAllPoolConnections(t *testing.T) {
t.Parallel()

// Arrange — open store and force multiple connections in the pool.
st := openTestStore(t)
db := st.DB()
db.SetMaxOpenConns(4)

ctx := context.Background()

// Act — grab several raw connections and check pragmas on each.
for i := range 4 {
conn, err := db.Conn(ctx)
if err != nil {
t.Fatalf("conn %d: %v", i, err)
}

var timeout int
if err := conn.QueryRowContext(ctx, "PRAGMA busy_timeout").Scan(&timeout); err != nil {
t.Fatalf("conn %d: querying busy_timeout: %v", i, err)
}
if timeout != 5000 {
t.Errorf("conn %d: busy_timeout = %d, want 5000", i, timeout)
}

var fk int
if err := conn.QueryRowContext(ctx, "PRAGMA foreign_keys").Scan(&fk); err != nil {
t.Fatalf("conn %d: querying foreign_keys: %v", i, err)
}
if fk != 1 {
t.Errorf("conn %d: foreign_keys = %d, want 1", i, fk)
}

conn.Close()
}
}

func TestBuildDSN_ContainsPragmas(t *testing.T) {
t.Parallel()

dsn := buildDSN("/tmp/test.db")

// url.Values.Encode() percent-encodes parentheses, so check
// the encoded form that the driver actually receives.
for _, want := range []string{
"_pragma=journal_mode%28WAL%29",
"_pragma=busy_timeout%285000%29",
"_pragma=foreign_keys%281%29",
} {
if !strings.Contains(dsn, want) {
t.Errorf("DSN %q missing pragma %q", dsn, want)
}
}

if !strings.HasPrefix(dsn, "/tmp/test.db?") {
t.Errorf("DSN %q does not start with expected path", dsn)
}
}
Loading