Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 44 additions & 12 deletions pgconn/auth_scram.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
return err
}

if cache := c.config.ScramDeriveCache; cache != nil && sc.password != "" {
if sp, ok := cache.Get(sc.deriveFingerprint()); ok {
sc.saltedPassword = sp
sc.hasSaltedPassword = true
sc.saltedPasswordFromCache = true
}
}

// Send client-final-message in a SASLResponse
saslResponse := &pgproto3.SASLResponse{
Data: []byte(sc.clientFinalMessage()),
Expand All @@ -106,9 +114,25 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
// Receive server-final-message payload in an AuthenticationSASLFinal.
saslFinal, err := c.rxSASLFinal()
if err != nil {
scramInvalidateDeriveCache(c, sc)
return err
}
err = sc.recvServerFinalMessage(saslFinal.Data)
if err != nil {
scramInvalidateDeriveCache(c, sc)
return err
}
return sc.recvServerFinalMessage(saslFinal.Data)
if cache := c.config.ScramDeriveCache; cache != nil && !sc.saltedPasswordFromCache && sc.hasSaltedPassword {
cache.Put(sc.deriveFingerprint(), sc.saltedPassword)
}
return nil
}

func scramInvalidateDeriveCache(c *PgConn, sc *scramClient) {
if c.config.ScramDeriveCache == nil || !sc.saltedPasswordFromCache {
return
}
c.config.ScramDeriveCache.Delete(sc.deriveFingerprint())
}

func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
Expand Down Expand Up @@ -172,10 +196,12 @@ type scramClient struct {
serverFirstMessage []byte
clientAndServerNonce []byte
salt []byte
iterations int
iterations uint64

saltedPassword []byte
authMessage []byte
saltedPassword ScramSaltedPassword
hasSaltedPassword bool
saltedPasswordFromCache bool
authMessage []byte
}

func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
Expand Down Expand Up @@ -279,8 +305,8 @@ func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
return fmt.Errorf("invalid SCRAM salt received from server: %w", err)
}

sc.iterations, err = strconv.Atoi(string(iterationsStr))
if err != nil || sc.iterations <= 0 {
sc.iterations, err = strconv.ParseUint(string(iterationsStr), 10, 64)
if err != nil || sc.iterations == 0 {
return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err)
}

Expand Down Expand Up @@ -310,14 +336,20 @@ func (sc *scramClient) clientFinalMessage() string {
channelBindingEncoded := base64.StdEncoding.EncodeToString(channelBindInput)
clientFinalMessageWithoutProof := fmt.Appendf(nil, "c=%s,r=%s", channelBindingEncoded, sc.clientAndServerNonce)

var err error
sc.saltedPassword, err = pbkdf2.Key(sha256.New, sc.password, sc.salt, sc.iterations, 32)
if err != nil {
panic(err) // This should never happen.
if !sc.hasSaltedPassword {
sp, err := pbkdf2.Key(sha256.New, sc.password, sc.salt, int(sc.iterations), 32)
if err != nil {
panic(err) // This should never happen.
}
if len(sp) != 32 {
panic("unexpected PBKDF2 output length")
}
copy(sc.saltedPassword[:], sp)
sc.hasSaltedPassword = true
}
sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))

clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)
clientProof := computeClientProof(sc.saltedPassword[:], sc.authMessage)

return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof)
}
Expand All @@ -329,7 +361,7 @@ func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error {

serverSignature := serverFinalMessage[2:]

if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) {
if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword[:], sc.authMessage)) {
return errors.New("invalid SCRAM ServerSignature received from server")
}

Expand Down
2 changes: 1 addition & 1 deletion pgconn/auth_scram_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ func TestScramClientRecvServerFinalMessage(t *testing.T) {

sc := setup(t)

validSignature := computeServerSignature(sc.saltedPassword, sc.authMessage)
validSignature := computeServerSignature(sc.saltedPassword[:], sc.authMessage)
err := sc.recvServerFinalMessage(append([]byte("v="), validSignature...))
require.NoError(t, err)
})
Expand Down
7 changes: 7 additions & 0 deletions pgconn/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ type Config struct {
// Valid values: "disable", "prefer", "require". Defaults to "prefer".
ChannelBinding string

// ScramDeriveCache, if non-nil, caches SCRAM-SHA-256 PBKDF2 output (the 32-byte salted password)
// under opaque [ScramCacheFingerprint] values computed by pgx during SCRAM. This can reduce per-connection CPU when many
// connections share the same verifier. Entries are invalidated on authentication failure when
// the cached material was used. The same cache instance is shared when [Config.Copy] copies
// the interface value (typical for connection pools).
ScramDeriveCache ScramDeriveCache

createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}

Expand Down
76 changes: 76 additions & 0 deletions pgconn/scram_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package pgconn

import (
"crypto/sha256"
"encoding/binary"
"sync"
)

// ScramCacheFingerprint is an opaque cache entry identifier for [ScramDeriveCache].
// pgx computes it while authenticating; implementations must use only the values
// passed to Get, Put, and Delete (do not construct fingerprints yourself).
type ScramCacheFingerprint [32]byte

// ScramSaltedPassword is the 32-byte PBKDF2 output cached by [ScramDeriveCache].
type ScramSaltedPassword [32]byte

func (sc *scramClient) deriveFingerprint() ScramCacheFingerprint {
h := sha256.New()
h.Write([]byte(sc.password))
h.Write([]byte{0})
h.Write(sc.salt)
h.Write([]byte{0})
var iter [8]byte
binary.BigEndian.PutUint64(iter[:], sc.iterations)
h.Write(iter[:])
var fp ScramCacheFingerprint
h.Sum(fp[:0])
return fp
}

// ScramDeriveCache stores the 32-byte SCRAM salted password (PBKDF2 output) so new connections
// can skip PBKDF2 when the server verifier is unchanged.
//
// Method arguments never include passwords or raw salts—only opaque [ScramCacheFingerprint]
// values produced by pgx during SCRAM.
type ScramDeriveCache interface {
Get(ScramCacheFingerprint) (ScramSaltedPassword, bool)
Put(ScramCacheFingerprint, ScramSaltedPassword)
Delete(ScramCacheFingerprint)
}

// SimpleScramDeriveCache is a small mutex-backed map implementation of [ScramDeriveCache]
// for tests and applications that do not need LRU eviction.
type SimpleScramDeriveCache struct {
mu sync.Mutex
m map[ScramCacheFingerprint]ScramSaltedPassword
}

// NewSimpleScramDeriveCache returns an empty [SimpleScramDeriveCache].
func NewSimpleScramDeriveCache() *SimpleScramDeriveCache {
return &SimpleScramDeriveCache{
m: make(map[ScramCacheFingerprint]ScramSaltedPassword),
}
}

// Get implements [ScramDeriveCache].
func (s *SimpleScramDeriveCache) Get(fp ScramCacheFingerprint) (ScramSaltedPassword, bool) {
s.mu.Lock()
defer s.mu.Unlock()
v, ok := s.m[fp]
return v, ok
}

// Put implements [ScramDeriveCache].
func (s *SimpleScramDeriveCache) Put(fp ScramCacheFingerprint, sp ScramSaltedPassword) {
s.mu.Lock()
defer s.mu.Unlock()
s.m[fp] = sp
}

// Delete implements [ScramDeriveCache].
func (s *SimpleScramDeriveCache) Delete(fp ScramCacheFingerprint) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.m, fp)
}
59 changes: 59 additions & 0 deletions pgconn/scram_cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package pgconn

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func testScramClient(password string, salt []byte, iterations uint64) *scramClient {
return &scramClient{
password: password,
salt: salt,
iterations: iterations,
}
}

func TestScramDeriveFingerprintStable(t *testing.T) {
t.Parallel()
sc := testScramClient("pw", []byte{1, 2, 3}, 4096)
fp1 := sc.deriveFingerprint()
fp2 := sc.deriveFingerprint()
assert.Equal(t, fp1, fp2)
}

func TestScramDeriveFingerprintDistinct(t *testing.T) {
t.Parallel()
a := testScramClient("a", []byte{1}, 1).deriveFingerprint()
b := testScramClient("b", []byte{1}, 1).deriveFingerprint()
assert.NotEqual(t, a, b)
}

func TestSimpleScramDeriveCache(t *testing.T) {
t.Parallel()

c := NewSimpleScramDeriveCache()
fp := testScramClient("pw", []byte{1, 2, 3}, 4096).deriveFingerprint()

var sp ScramSaltedPassword
for i := range sp {
sp[i] = byte(i)
}

_, ok := c.Get(fp)
assert.False(t, ok)

c.Put(fp, sp)
got, ok := c.Get(fp)
require.True(t, ok)
assert.Equal(t, sp, got)
got[0] ^= 0xff
got2, ok := c.Get(fp)
require.True(t, ok)
assert.Equal(t, sp, got2)

c.Delete(fp)
_, ok = c.Get(fp)
assert.False(t, ok)
}
Loading