diff --git a/pgconn/auth_scram.go b/pgconn/auth_scram.go index f59d39c4e..f9defcdc3 100644 --- a/pgconn/auth_scram.go +++ b/pgconn/auth_scram.go @@ -93,6 +93,14 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { return err } + if cache := c.config.ScramDeriveCache; cache != nil && sc.password != "" { + if sp, ok := cache.Get(sc.deriveFingerprint()); ok { + sc.saltedPassword = sp + sc.hasSaltedPassword = true + sc.saltedPasswordFromCache = true + } + } + // Send client-final-message in a SASLResponse saslResponse := &pgproto3.SASLResponse{ Data: []byte(sc.clientFinalMessage()), @@ -106,9 +114,25 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { // Receive server-final-message payload in an AuthenticationSASLFinal. saslFinal, err := c.rxSASLFinal() if err != nil { + scramInvalidateDeriveCache(c, sc) + return err + } + err = sc.recvServerFinalMessage(saslFinal.Data) + if err != nil { + scramInvalidateDeriveCache(c, sc) return err } - return sc.recvServerFinalMessage(saslFinal.Data) + if cache := c.config.ScramDeriveCache; cache != nil && !sc.saltedPasswordFromCache && sc.hasSaltedPassword { + cache.Put(sc.deriveFingerprint(), sc.saltedPassword) + } + return nil +} + +func scramInvalidateDeriveCache(c *PgConn, sc *scramClient) { + if c.config.ScramDeriveCache == nil || !sc.saltedPasswordFromCache { + return + } + c.config.ScramDeriveCache.Delete(sc.deriveFingerprint()) } func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) { @@ -172,10 +196,12 @@ type scramClient struct { serverFirstMessage []byte clientAndServerNonce []byte salt []byte - iterations int + iterations uint64 - saltedPassword []byte - authMessage []byte + saltedPassword ScramSaltedPassword + hasSaltedPassword bool + saltedPasswordFromCache bool + authMessage []byte } func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) { @@ -279,8 +305,8 @@ func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { return fmt.Errorf("invalid SCRAM salt received from server: %w", err) } - sc.iterations, err = strconv.Atoi(string(iterationsStr)) - if err != nil || sc.iterations <= 0 { + sc.iterations, err = strconv.ParseUint(string(iterationsStr), 10, 64) + if err != nil || sc.iterations == 0 { return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err) } @@ -310,14 +336,20 @@ func (sc *scramClient) clientFinalMessage() string { channelBindingEncoded := base64.StdEncoding.EncodeToString(channelBindInput) clientFinalMessageWithoutProof := fmt.Appendf(nil, "c=%s,r=%s", channelBindingEncoded, sc.clientAndServerNonce) - var err error - sc.saltedPassword, err = pbkdf2.Key(sha256.New, sc.password, sc.salt, sc.iterations, 32) - if err != nil { - panic(err) // This should never happen. + if !sc.hasSaltedPassword { + sp, err := pbkdf2.Key(sha256.New, sc.password, sc.salt, int(sc.iterations), 32) + if err != nil { + panic(err) // This should never happen. + } + if len(sp) != 32 { + panic("unexpected PBKDF2 output length") + } + copy(sc.saltedPassword[:], sp) + sc.hasSaltedPassword = true } sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(",")) - clientProof := computeClientProof(sc.saltedPassword, sc.authMessage) + clientProof := computeClientProof(sc.saltedPassword[:], sc.authMessage) return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof) } @@ -329,7 +361,7 @@ func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error { serverSignature := serverFinalMessage[2:] - if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) { + if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword[:], sc.authMessage)) { return errors.New("invalid SCRAM ServerSignature received from server") } diff --git a/pgconn/auth_scram_test.go b/pgconn/auth_scram_test.go index dcb500f8f..e50592814 100644 --- a/pgconn/auth_scram_test.go +++ b/pgconn/auth_scram_test.go @@ -408,7 +408,7 @@ func TestScramClientRecvServerFinalMessage(t *testing.T) { sc := setup(t) - validSignature := computeServerSignature(sc.saltedPassword, sc.authMessage) + validSignature := computeServerSignature(sc.saltedPassword[:], sc.authMessage) err := sc.recvServerFinalMessage(append([]byte("v="), validSignature...)) require.NoError(t, err) }) diff --git a/pgconn/config.go b/pgconn/config.go index dff550953..939bfe1bb 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -100,6 +100,13 @@ type Config struct { // Valid values: "disable", "prefer", "require". Defaults to "prefer". ChannelBinding string + // ScramDeriveCache, if non-nil, caches SCRAM-SHA-256 PBKDF2 output (the 32-byte salted password) + // under opaque [ScramCacheFingerprint] values computed by pgx during SCRAM. This can reduce per-connection CPU when many + // connections share the same verifier. Entries are invalidated on authentication failure when + // the cached material was used. The same cache instance is shared when [Config.Copy] copies + // the interface value (typical for connection pools). + ScramDeriveCache ScramDeriveCache + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } diff --git a/pgconn/scram_cache.go b/pgconn/scram_cache.go new file mode 100644 index 000000000..9497f986c --- /dev/null +++ b/pgconn/scram_cache.go @@ -0,0 +1,76 @@ +package pgconn + +import ( + "crypto/sha256" + "encoding/binary" + "sync" +) + +// ScramCacheFingerprint is an opaque cache entry identifier for [ScramDeriveCache]. +// pgx computes it while authenticating; implementations must use only the values +// passed to Get, Put, and Delete (do not construct fingerprints yourself). +type ScramCacheFingerprint [32]byte + +// ScramSaltedPassword is the 32-byte PBKDF2 output cached by [ScramDeriveCache]. +type ScramSaltedPassword [32]byte + +func (sc *scramClient) deriveFingerprint() ScramCacheFingerprint { + h := sha256.New() + h.Write([]byte(sc.password)) + h.Write([]byte{0}) + h.Write(sc.salt) + h.Write([]byte{0}) + var iter [8]byte + binary.BigEndian.PutUint64(iter[:], sc.iterations) + h.Write(iter[:]) + var fp ScramCacheFingerprint + h.Sum(fp[:0]) + return fp +} + +// ScramDeriveCache stores the 32-byte SCRAM salted password (PBKDF2 output) so new connections +// can skip PBKDF2 when the server verifier is unchanged. +// +// Method arguments never include passwords or raw salts—only opaque [ScramCacheFingerprint] +// values produced by pgx during SCRAM. +type ScramDeriveCache interface { + Get(ScramCacheFingerprint) (ScramSaltedPassword, bool) + Put(ScramCacheFingerprint, ScramSaltedPassword) + Delete(ScramCacheFingerprint) +} + +// SimpleScramDeriveCache is a small mutex-backed map implementation of [ScramDeriveCache] +// for tests and applications that do not need LRU eviction. +type SimpleScramDeriveCache struct { + mu sync.Mutex + m map[ScramCacheFingerprint]ScramSaltedPassword +} + +// NewSimpleScramDeriveCache returns an empty [SimpleScramDeriveCache]. +func NewSimpleScramDeriveCache() *SimpleScramDeriveCache { + return &SimpleScramDeriveCache{ + m: make(map[ScramCacheFingerprint]ScramSaltedPassword), + } +} + +// Get implements [ScramDeriveCache]. +func (s *SimpleScramDeriveCache) Get(fp ScramCacheFingerprint) (ScramSaltedPassword, bool) { + s.mu.Lock() + defer s.mu.Unlock() + v, ok := s.m[fp] + return v, ok +} + +// Put implements [ScramDeriveCache]. +func (s *SimpleScramDeriveCache) Put(fp ScramCacheFingerprint, sp ScramSaltedPassword) { + s.mu.Lock() + defer s.mu.Unlock() + s.m[fp] = sp +} + +// Delete implements [ScramDeriveCache]. +func (s *SimpleScramDeriveCache) Delete(fp ScramCacheFingerprint) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.m, fp) +} diff --git a/pgconn/scram_cache_test.go b/pgconn/scram_cache_test.go new file mode 100644 index 000000000..ecef25879 --- /dev/null +++ b/pgconn/scram_cache_test.go @@ -0,0 +1,59 @@ +package pgconn + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testScramClient(password string, salt []byte, iterations uint64) *scramClient { + return &scramClient{ + password: password, + salt: salt, + iterations: iterations, + } +} + +func TestScramDeriveFingerprintStable(t *testing.T) { + t.Parallel() + sc := testScramClient("pw", []byte{1, 2, 3}, 4096) + fp1 := sc.deriveFingerprint() + fp2 := sc.deriveFingerprint() + assert.Equal(t, fp1, fp2) +} + +func TestScramDeriveFingerprintDistinct(t *testing.T) { + t.Parallel() + a := testScramClient("a", []byte{1}, 1).deriveFingerprint() + b := testScramClient("b", []byte{1}, 1).deriveFingerprint() + assert.NotEqual(t, a, b) +} + +func TestSimpleScramDeriveCache(t *testing.T) { + t.Parallel() + + c := NewSimpleScramDeriveCache() + fp := testScramClient("pw", []byte{1, 2, 3}, 4096).deriveFingerprint() + + var sp ScramSaltedPassword + for i := range sp { + sp[i] = byte(i) + } + + _, ok := c.Get(fp) + assert.False(t, ok) + + c.Put(fp, sp) + got, ok := c.Get(fp) + require.True(t, ok) + assert.Equal(t, sp, got) + got[0] ^= 0xff + got2, ok := c.Get(fp) + require.True(t, ok) + assert.Equal(t, sp, got2) + + c.Delete(fp) + _, ok = c.Get(fp) + assert.False(t, ok) +}