From b43a34155451ea5f98eb10a12b42c840855e08cb Mon Sep 17 00:00:00 2001 From: Matt Robenolt Date: Fri, 20 Mar 2026 20:18:17 -0700 Subject: [PATCH] pgconn: add optional SCRAM-SHA-256 SaltedPassword cache MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit RFC 5802 §5.1 and RFC 7677 §4 both permit caching derived SCRAM key material to avoid repeated PBKDF2 computation: RFC 5802 §5.1: "a client implementation MAY cache ClientKey&ServerKey (or just SaltedPassword) for later reauthentication to the same service, as it is likely that the server is going to advertise the same salt value upon reauthentication." RFC 7677 §4: "This computational cost can be avoided by caching the ClientKey (assuming the Salt and hash iteration-count is stable)." Add a ScramDeriveCache interface on Config that, when set, caches the 32-byte SaltedPassword keyed by an opaque SHA-256 fingerprint of (password, salt, iterations). On subsequent connections with the same verifier, PBKDF2 is skipped entirely. Cache entries are invalidated on authentication failure when cached material was used, so a password change or salt rotation cannot produce stale hits. Includes SimpleScramDeriveCache, a mutex+map implementation suitable for most use cases. The interface allows callers to provide their own LRU or sharded implementations. --- pgconn/auth_scram.go | 56 ++++++++++++++++++++++------ pgconn/auth_scram_test.go | 2 +- pgconn/config.go | 7 ++++ pgconn/scram_cache.go | 76 ++++++++++++++++++++++++++++++++++++++ pgconn/scram_cache_test.go | 59 +++++++++++++++++++++++++++++ 5 files changed, 187 insertions(+), 13 deletions(-) create mode 100644 pgconn/scram_cache.go create mode 100644 pgconn/scram_cache_test.go diff --git a/pgconn/auth_scram.go b/pgconn/auth_scram.go index f59d39c4e..f9defcdc3 100644 --- a/pgconn/auth_scram.go +++ b/pgconn/auth_scram.go @@ -93,6 +93,14 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { return err } + if cache := c.config.ScramDeriveCache; cache != nil && sc.password != "" { + if sp, ok := cache.Get(sc.deriveFingerprint()); ok { + sc.saltedPassword = sp + sc.hasSaltedPassword = true + sc.saltedPasswordFromCache = true + } + } + // Send client-final-message in a SASLResponse saslResponse := &pgproto3.SASLResponse{ Data: []byte(sc.clientFinalMessage()), @@ -106,9 +114,25 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { // Receive server-final-message payload in an AuthenticationSASLFinal. saslFinal, err := c.rxSASLFinal() if err != nil { + scramInvalidateDeriveCache(c, sc) + return err + } + err = sc.recvServerFinalMessage(saslFinal.Data) + if err != nil { + scramInvalidateDeriveCache(c, sc) return err } - return sc.recvServerFinalMessage(saslFinal.Data) + if cache := c.config.ScramDeriveCache; cache != nil && !sc.saltedPasswordFromCache && sc.hasSaltedPassword { + cache.Put(sc.deriveFingerprint(), sc.saltedPassword) + } + return nil +} + +func scramInvalidateDeriveCache(c *PgConn, sc *scramClient) { + if c.config.ScramDeriveCache == nil || !sc.saltedPasswordFromCache { + return + } + c.config.ScramDeriveCache.Delete(sc.deriveFingerprint()) } func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) { @@ -172,10 +196,12 @@ type scramClient struct { serverFirstMessage []byte clientAndServerNonce []byte salt []byte - iterations int + iterations uint64 - saltedPassword []byte - authMessage []byte + saltedPassword ScramSaltedPassword + hasSaltedPassword bool + saltedPasswordFromCache bool + authMessage []byte } func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) { @@ -279,8 +305,8 @@ func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { return fmt.Errorf("invalid SCRAM salt received from server: %w", err) } - sc.iterations, err = strconv.Atoi(string(iterationsStr)) - if err != nil || sc.iterations <= 0 { + sc.iterations, err = strconv.ParseUint(string(iterationsStr), 10, 64) + if err != nil || sc.iterations == 0 { return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err) } @@ -310,14 +336,20 @@ func (sc *scramClient) clientFinalMessage() string { channelBindingEncoded := base64.StdEncoding.EncodeToString(channelBindInput) clientFinalMessageWithoutProof := fmt.Appendf(nil, "c=%s,r=%s", channelBindingEncoded, sc.clientAndServerNonce) - var err error - sc.saltedPassword, err = pbkdf2.Key(sha256.New, sc.password, sc.salt, sc.iterations, 32) - if err != nil { - panic(err) // This should never happen. + if !sc.hasSaltedPassword { + sp, err := pbkdf2.Key(sha256.New, sc.password, sc.salt, int(sc.iterations), 32) + if err != nil { + panic(err) // This should never happen. + } + if len(sp) != 32 { + panic("unexpected PBKDF2 output length") + } + copy(sc.saltedPassword[:], sp) + sc.hasSaltedPassword = true } sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(",")) - clientProof := computeClientProof(sc.saltedPassword, sc.authMessage) + clientProof := computeClientProof(sc.saltedPassword[:], sc.authMessage) return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof) } @@ -329,7 +361,7 @@ func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error { serverSignature := serverFinalMessage[2:] - if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) { + if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword[:], sc.authMessage)) { return errors.New("invalid SCRAM ServerSignature received from server") } diff --git a/pgconn/auth_scram_test.go b/pgconn/auth_scram_test.go index dcb500f8f..e50592814 100644 --- a/pgconn/auth_scram_test.go +++ b/pgconn/auth_scram_test.go @@ -408,7 +408,7 @@ func TestScramClientRecvServerFinalMessage(t *testing.T) { sc := setup(t) - validSignature := computeServerSignature(sc.saltedPassword, sc.authMessage) + validSignature := computeServerSignature(sc.saltedPassword[:], sc.authMessage) err := sc.recvServerFinalMessage(append([]byte("v="), validSignature...)) require.NoError(t, err) }) diff --git a/pgconn/config.go b/pgconn/config.go index dff550953..939bfe1bb 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -100,6 +100,13 @@ type Config struct { // Valid values: "disable", "prefer", "require". Defaults to "prefer". ChannelBinding string + // ScramDeriveCache, if non-nil, caches SCRAM-SHA-256 PBKDF2 output (the 32-byte salted password) + // under opaque [ScramCacheFingerprint] values computed by pgx during SCRAM. This can reduce per-connection CPU when many + // connections share the same verifier. Entries are invalidated on authentication failure when + // the cached material was used. The same cache instance is shared when [Config.Copy] copies + // the interface value (typical for connection pools). + ScramDeriveCache ScramDeriveCache + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } diff --git a/pgconn/scram_cache.go b/pgconn/scram_cache.go new file mode 100644 index 000000000..9497f986c --- /dev/null +++ b/pgconn/scram_cache.go @@ -0,0 +1,76 @@ +package pgconn + +import ( + "crypto/sha256" + "encoding/binary" + "sync" +) + +// ScramCacheFingerprint is an opaque cache entry identifier for [ScramDeriveCache]. +// pgx computes it while authenticating; implementations must use only the values +// passed to Get, Put, and Delete (do not construct fingerprints yourself). +type ScramCacheFingerprint [32]byte + +// ScramSaltedPassword is the 32-byte PBKDF2 output cached by [ScramDeriveCache]. +type ScramSaltedPassword [32]byte + +func (sc *scramClient) deriveFingerprint() ScramCacheFingerprint { + h := sha256.New() + h.Write([]byte(sc.password)) + h.Write([]byte{0}) + h.Write(sc.salt) + h.Write([]byte{0}) + var iter [8]byte + binary.BigEndian.PutUint64(iter[:], sc.iterations) + h.Write(iter[:]) + var fp ScramCacheFingerprint + h.Sum(fp[:0]) + return fp +} + +// ScramDeriveCache stores the 32-byte SCRAM salted password (PBKDF2 output) so new connections +// can skip PBKDF2 when the server verifier is unchanged. +// +// Method arguments never include passwords or raw salts—only opaque [ScramCacheFingerprint] +// values produced by pgx during SCRAM. +type ScramDeriveCache interface { + Get(ScramCacheFingerprint) (ScramSaltedPassword, bool) + Put(ScramCacheFingerprint, ScramSaltedPassword) + Delete(ScramCacheFingerprint) +} + +// SimpleScramDeriveCache is a small mutex-backed map implementation of [ScramDeriveCache] +// for tests and applications that do not need LRU eviction. +type SimpleScramDeriveCache struct { + mu sync.Mutex + m map[ScramCacheFingerprint]ScramSaltedPassword +} + +// NewSimpleScramDeriveCache returns an empty [SimpleScramDeriveCache]. +func NewSimpleScramDeriveCache() *SimpleScramDeriveCache { + return &SimpleScramDeriveCache{ + m: make(map[ScramCacheFingerprint]ScramSaltedPassword), + } +} + +// Get implements [ScramDeriveCache]. +func (s *SimpleScramDeriveCache) Get(fp ScramCacheFingerprint) (ScramSaltedPassword, bool) { + s.mu.Lock() + defer s.mu.Unlock() + v, ok := s.m[fp] + return v, ok +} + +// Put implements [ScramDeriveCache]. +func (s *SimpleScramDeriveCache) Put(fp ScramCacheFingerprint, sp ScramSaltedPassword) { + s.mu.Lock() + defer s.mu.Unlock() + s.m[fp] = sp +} + +// Delete implements [ScramDeriveCache]. +func (s *SimpleScramDeriveCache) Delete(fp ScramCacheFingerprint) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.m, fp) +} diff --git a/pgconn/scram_cache_test.go b/pgconn/scram_cache_test.go new file mode 100644 index 000000000..ecef25879 --- /dev/null +++ b/pgconn/scram_cache_test.go @@ -0,0 +1,59 @@ +package pgconn + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testScramClient(password string, salt []byte, iterations uint64) *scramClient { + return &scramClient{ + password: password, + salt: salt, + iterations: iterations, + } +} + +func TestScramDeriveFingerprintStable(t *testing.T) { + t.Parallel() + sc := testScramClient("pw", []byte{1, 2, 3}, 4096) + fp1 := sc.deriveFingerprint() + fp2 := sc.deriveFingerprint() + assert.Equal(t, fp1, fp2) +} + +func TestScramDeriveFingerprintDistinct(t *testing.T) { + t.Parallel() + a := testScramClient("a", []byte{1}, 1).deriveFingerprint() + b := testScramClient("b", []byte{1}, 1).deriveFingerprint() + assert.NotEqual(t, a, b) +} + +func TestSimpleScramDeriveCache(t *testing.T) { + t.Parallel() + + c := NewSimpleScramDeriveCache() + fp := testScramClient("pw", []byte{1, 2, 3}, 4096).deriveFingerprint() + + var sp ScramSaltedPassword + for i := range sp { + sp[i] = byte(i) + } + + _, ok := c.Get(fp) + assert.False(t, ok) + + c.Put(fp, sp) + got, ok := c.Get(fp) + require.True(t, ok) + assert.Equal(t, sp, got) + got[0] ^= 0xff + got2, ok := c.Get(fp) + require.True(t, ok) + assert.Equal(t, sp, got2) + + c.Delete(fp) + _, ok = c.Get(fp) + assert.False(t, ok) +}