Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func New(ctx context.Context, logger log.Logger, config DatabaseConfig) (*sql.DB
if err != nil {
return nil, fmt.Errorf("connecting to postgres: %w", err)
}
return ApplyConnectionsConfig(db, &config.Postgres.Connections, logger), nil
return ApplyPostgresConnectionsConfig(db, &config.Postgres.Connections, logger), nil
}

return nil, ErrMissingConfig
Expand Down Expand Up @@ -112,3 +112,27 @@ func ApplyConnectionsConfig(db *sql.DB, connections *ConnectionsConfig, logger l

return db
}

// ApplyPostgresConnectionsConfig applies connection pool settings to db, filling in
// defaults for Postgres/AlloyDB wherever the caller left a value unset.
//
// Every field of connections that is zero or negative is replaced by the matching
// value from DefaultPostgresConnectionsConfig, so a non-positive value cannot be used
// to switch a limit off through this function. A nil connections is treated as an
// empty config, i.e. every default applies. The caller's config is copied and never
// modified. The merged settings are passed to ApplyConnectionsConfig and its result
// is returned.
func ApplyPostgresConnectionsConfig(db *sql.DB, connections *ConnectionsConfig, logger log.Logger) *sql.DB {
	defaults := DefaultPostgresConnectionsConfig()

	// Work on a copy; a nil pointer simply leaves the zero value, which the
	// checks below turn into the defaults.
	var applied ConnectionsConfig
	if connections != nil {
		applied = *connections
	}

	if applied.MaxOpen <= 0 {
		applied.MaxOpen = defaults.MaxOpen
	}
	if applied.MaxIdle <= 0 {
		applied.MaxIdle = defaults.MaxIdle
	}
	if applied.MaxLifetime <= 0 {
		applied.MaxLifetime = defaults.MaxLifetime
	}
	if applied.MaxIdleTime <= 0 {
		applied.MaxIdleTime = defaults.MaxIdleTime
	}

	return ApplyConnectionsConfig(db, &applied, logger)
}
49 changes: 49 additions & 0 deletions database/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ package database_test

import (
"bytes"
"database/sql"
"errors"
"os"
"testing"
"time"

gomysql "github.com/go-sql-driver/mysql"
"github.com/jackc/pgx/v5/pgconn"
"github.com/moov-io/base/database"
"github.com/moov-io/base/log"

"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -121,6 +124,52 @@ func TestDataTooLong(t *testing.T) {
}
}

// TestApplyPostgresConnectionsConfig_Defaults checks that an empty ConnectionsConfig
// results in the default MaxOpen being reported by the pool's stats.
// The test is skipped when sql.Open fails for the "txdb" driver.
func TestApplyPostgresConnectionsConfig_Defaults(t *testing.T) {
	conn, openErr := sql.Open("txdb", "TestApplyPostgresConnectionsConfig_Defaults")
	if openErr != nil {
		t.Skip("skipping test without txdb driver")
	}
	defer conn.Close()

	// Nothing is configured, so every setting should come from the defaults.
	database.ApplyPostgresConnectionsConfig(conn, &database.ConnectionsConfig{}, log.NewTestLogger())

	wantMaxOpen := database.DefaultPostgresConnectionsConfig().MaxOpen
	require.Equal(t, wantMaxOpen, conn.Stats().MaxOpenConnections)
}

// TestApplyPostgresConnectionsConfig_Overrides checks that an explicitly configured
// MaxOpen is the value reported by the pool's stats, rather than the default.
// The test is skipped when sql.Open fails for the "txdb" driver.
func TestApplyPostgresConnectionsConfig_Overrides(t *testing.T) {
	conn, openErr := sql.Open("txdb", "TestApplyPostgresConnectionsConfig_Overrides")
	if openErr != nil {
		t.Skip("skipping test without txdb driver")
	}
	defer conn.Close()

	const wantMaxOpen = 10

	// All four settings are positive, so none should be replaced by a default.
	overrides := database.ConnectionsConfig{
		MaxOpen:     wantMaxOpen,
		MaxIdle:     3,
		MaxLifetime: time.Minute,
		MaxIdleTime: time.Second * 15,
	}
	database.ApplyPostgresConnectionsConfig(conn, &overrides, log.NewTestLogger())

	require.Equal(t, wantMaxOpen, conn.Stats().MaxOpenConnections)
}

// TestDefaultPostgresConnectionsConfig checks the invariants of the default pool
// settings: every value is strictly positive and the idle timeout is shorter than
// the maximum connection lifetime.
func TestDefaultPostgresConnectionsConfig(t *testing.T) {
	cfg := database.DefaultPostgresConnectionsConfig()

	// Connection counts.
	require.Greater(t, cfg.MaxOpen, 0)
	require.Greater(t, cfg.MaxIdle, 0)

	// Durations.
	var zero time.Duration
	require.Greater(t, cfg.MaxLifetime, zero)
	require.Greater(t, cfg.MaxIdleTime, zero)

	// A connection should hit its idle timeout before its maximum lifetime.
	require.Less(t, cfg.MaxIdleTime, cfg.MaxLifetime)
}

func TestConnectionsConfigOrder(t *testing.T) {
bs, err := os.ReadFile("database.go")
require.NoError(t, err)
Expand Down
13 changes: 13 additions & 0 deletions database/model_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,19 @@ type ConnectionsConfig struct {
MaxIdleTime time.Duration
}

// DefaultPostgresConnectionsConfig returns the connection pool defaults used for
// Postgres: MaxOpen 25, MaxIdle 5, MaxLifetime 5 minutes and MaxIdleTime 30 seconds.
//
// The lifetimes are deliberately short. The intent is that connections left dead by
// a database failover (e.g. an AlloyDB maintenance switchover) are evicted quickly,
// so the pool reconnects to the new primary.
func DefaultPostgresConnectionsConfig() ConnectionsConfig {
	var defaults ConnectionsConfig

	defaults.MaxOpen = 25
	defaults.MaxIdle = 5
	defaults.MaxLifetime = 5 * time.Minute
	defaults.MaxIdleTime = 30 * time.Second

	return defaults
}

type RetryConfig struct {
MaxAttempts int
MinDuration time.Duration
Expand Down
192 changes: 138 additions & 54 deletions database/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import (
"database/sql"
"errors"
"fmt"
"io"
"net"
"net/url"
"strings"
"time"

"cloud.google.com/go/alloydbconn"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib"
"github.com/moov-io/base/log"
)
Expand All @@ -23,67 +25,50 @@ const (
)

func postgresConnection(ctx context.Context, logger log.Logger, config PostgresConfig, databaseName string) (*sql.DB, error) {
var connStr string
if config.Alloy != nil {
c, err := getAlloyDBConnectorConnStr(ctx, config, databaseName)
if err != nil {
return nil, logger.LogErrorf("creating alloydb connection: %w", err).Err()
}
connStr = c
} else {
c, err := getPostgresConnStr(config, databaseName)
if err != nil {
return nil, logger.LogErrorf("creating postgres connection: %w", err).Err()
}
connStr = c
poolConfig, err := buildPgxPoolConfig(ctx, config, databaseName)
if err != nil {
return nil, logger.LogErrorf("building pgx pool config: %w", err).Err()
}

db, err := sql.Open("pgx", connStr)
// HealthCheckPeriod makes pgxpool ping idle connections in the background.
// Dead connections (e.g. from an AlloyDB switchover) are evicted before
// the application ever sees them.
poolConfig.HealthCheckPeriod = 1 * time.Second

pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
if err != nil {
return nil, logger.LogErrorf("opening database: %w", err).Err()
return nil, logger.LogErrorf("creating pgx pool: %w", err).Err()
}

err = db.Ping()
err = pool.Ping(ctx)
if err != nil {
_ = db.Close()
pool.Close()
return nil, logger.LogErrorf("connecting to database: %w", err).Err()
}

// Wrap the pgxpool in a *sql.DB so the rest of the codebase doesn't change.
// pgxpool manages the real pool (with health checks); database/sql pool
// settings are applied on top via ApplyPostgresConnectionsConfig.
db := stdlib.OpenDBFromPool(pool)

return db, nil
}

func getPostgresConnStr(config PostgresConfig, databaseName string) (string, error) {
url := fmt.Sprintf("postgres://%s:%s@%s/%s", config.User, config.Password, config.Address, databaseName)

params := ""

if config.TLS != nil {
if len(config.TLS.Mode) < 1 {
config.TLS.Mode = "verify-full"
}

params += "sslmode=" + config.TLS.Mode

if len(config.TLS.CACertFile) > 0 {
params += "&sslrootcert=" + config.TLS.CACertFile
}

if len(config.TLS.ClientCertFile) > 0 {
params += "&sslcert=" + config.TLS.ClientCertFile
}

if len(config.TLS.ClientKeyFile) > 0 {
params += "&sslkey=" + config.TLS.ClientKeyFile
}
func buildPgxPoolConfig(ctx context.Context, config PostgresConfig, databaseName string) (*pgxpool.Config, error) {
if config.Alloy != nil {
return buildAlloyDBPoolConfig(ctx, config, databaseName)
}

connStr := fmt.Sprintf("%s?%s", url, params)
return connStr, nil
connStr, err := getPostgresConnStr(config, databaseName)
if err != nil {
return nil, err
}
return pgxpool.ParseConfig(connStr)
}

func getAlloyDBConnectorConnStr(ctx context.Context, config PostgresConfig, databaseName string) (string, error) {
func buildAlloyDBPoolConfig(ctx context.Context, config PostgresConfig, databaseName string) (*pgxpool.Config, error) {
if config.Alloy == nil {
return "", fmt.Errorf("missing alloy config")
return nil, fmt.Errorf("missing alloy config")
}

var dialer *alloydbconn.Dialer
Expand All @@ -92,7 +77,7 @@ func getAlloyDBConnectorConnStr(ctx context.Context, config PostgresConfig, data
if config.Alloy.UseIAM {
d, err := alloydbconn.NewDialer(ctx, alloydbconn.WithIAMAuthN())
if err != nil {
return "", fmt.Errorf("creating alloydb dialer: %v", err)
return nil, fmt.Errorf("creating alloydb dialer: %v", err)
}
dialer = d
dsn = fmt.Sprintf(
Expand All @@ -104,7 +89,7 @@ func getAlloyDBConnectorConnStr(ctx context.Context, config PostgresConfig, data
} else {
d, err := alloydbconn.NewDialer(ctx)
if err != nil {
return "", fmt.Errorf("creating alloydb dialer: %v", err)
return nil, fmt.Errorf("creating alloydb dialer: %v", err)
}
dialer = d
dsn = fmt.Sprintf(
Expand All @@ -114,24 +99,49 @@ func getAlloyDBConnectorConnStr(ctx context.Context, config PostgresConfig, data
)
}

// TODO
//cleanup := func() error { return d.Close() }

connConfig, err := pgx.ParseConfig(dsn)
poolConfig, err := pgxpool.ParseConfig(dsn)
if err != nil {
return "", fmt.Errorf("failed to parse pgx config: %v", err)
return nil, fmt.Errorf("failed to parse pgx pool config: %v", err)
}

var connOptions []alloydbconn.DialOption
if config.Alloy.UsePSC {
connOptions = append(connOptions, alloydbconn.WithPSC())
}

connConfig.DialFunc = func(ctx context.Context, _ string, _ string) (net.Conn, error) {
poolConfig.ConnConfig.DialFunc = func(ctx context.Context, _ string, _ string) (net.Conn, error) {
return dialer.Dial(ctx, config.Alloy.InstanceURI, connOptions...)
}

connStr := stdlib.RegisterConnConfig(connConfig)
return poolConfig, nil
}

// getPostgresConnStr builds a postgres:// connection URL for databaseName from config.
//
// The user and password are percent-encoded as URL userinfo, the database name is
// path-escaped and TLS parameter values are query-escaped, so credentials or file
// paths containing reserved characters (such as '@', '/', ':', '?', '#' or '&')
// still produce a URL that parses. config.Address is used verbatim.
//
// When config.TLS is set, sslmode defaults to "verify-full" if no mode is given, and
// sslrootcert, sslcert and sslkey are appended only for non-empty file paths. The
// caller's TLS config is never modified. The returned error is always nil.
func getPostgresConnStr(config PostgresConfig, databaseName string) (string, error) {
	credentials := url.UserPassword(config.User, config.Password).String()
	base := fmt.Sprintf("postgres://%s@%s/%s", credentials, config.Address, url.PathEscape(databaseName))

	var params []string
	if tlsConfig := config.TLS; tlsConfig != nil {
		// Resolve the default locally: config.TLS is a pointer shared with the
		// caller, so writing the default back would change the caller's config.
		mode := tlsConfig.Mode
		if mode == "" {
			mode = "verify-full"
		}
		params = append(params, "sslmode="+url.QueryEscape(mode))

		if tlsConfig.CACertFile != "" {
			params = append(params, "sslrootcert="+url.QueryEscape(tlsConfig.CACertFile))
		}
		if tlsConfig.ClientCertFile != "" {
			params = append(params, "sslcert="+url.QueryEscape(tlsConfig.ClientCertFile))
		}
		if tlsConfig.ClientKeyFile != "" {
			params = append(params, "sslkey="+url.QueryEscape(tlsConfig.ClientKeyFile))
		}
	}

	// The "?" separator is emitted even when there are no parameters, which keeps
	// the output identical to earlier versions for inputs that need no escaping.
	return base + "?" + strings.Join(params, "&"), nil
}

Expand Down Expand Up @@ -164,3 +174,77 @@ func PostgresDeadlockFound(err error) bool {

return strings.Contains(err.Error(), postgresErrDeadlockFound)
}

// IsRetryablePostgresError returns true if the error is a transient connection-level
// error that is safe to retry. This covers the errors seen during AlloyDB maintenance
// switchovers and other transient network failures.
func IsRetryablePostgresError(err error) bool {
if err == nil {
return false
}

// PostgreSQL error codes indicating the server is shutting down or unavailable
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
switch pgErr.Code {
case "57P01", "57P02", "57P03": // admin_shutdown, crash_shutdown, cannot_connect_now
return true
case "08000", "08001", "08003", "08004", "08006": // connection_exception class
return true
Comment on lines +190 to +193
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to see documentation on these error cases as being valid to retry. I'm unsure I trust an AI here, given the possible data corruption if it's incorrect. Otherwise, just remove the retry stuff and put it into a different PR.

}
Comment thread
adamdecaf marked this conversation as resolved.
return false
}

// Network-level errors: connection reset, broken pipe, EOF, etc.
// These occur when the TCP connection is severed during a switchover.
var netErr *net.OpError
if errors.As(err, &netErr) {
return true
}
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
return true
}
if errors.Is(err, context.DeadlineExceeded) {
return false // don't retry if the caller's context timed out
}

// pgx wraps connection errors with these messages
msg := err.Error()
if strings.Contains(msg, "connection reset by peer") ||
strings.Contains(msg, "broken pipe") ||
strings.Contains(msg, "connection refused") ||
strings.Contains(msg, "unexpected EOF") ||
strings.Contains(msg, "conn closed") {
return true
}

return false
}

// RetryPostgres executes fn up to maxAttempts times, retrying on transient
// connection errors. This is intended for use around individual database
// operations to survive brief outages like AlloyDB maintenance switchovers.
func RetryPostgres(ctx context.Context, maxAttempts int, fn func() error) error {
if maxAttempts <= 0 {
maxAttempts = 3
}
var err error
for attempt := 0; attempt < maxAttempts; attempt++ {
err = fn()
if err == nil {
return nil
}
if !IsRetryablePostgresError(err) {
return err
}
if attempt < maxAttempts-1 {
backoff := time.Duration(attempt+1) * 200 * time.Millisecond
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Backoff is a bad idea in this case as you'll drastically increase the wait, along with it not being configurable. 200ms is an ETERNITY to a program, and is noticeable by a user. There's also no variance (jitter) in this interface, so they will all slam at the same time.

select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(backoff):
}
}
}
return err
}
Loading
Loading