From cc542e9aac02230af4ec16486aee9d9cc3412c31 Mon Sep 17 00:00:00 2001 From: Ethan Lin Date: Tue, 17 Feb 2026 17:27:23 -0800 Subject: [PATCH] [COR-21902] Add dial timeout config for client --- memcache/memcache.go | 15 ++++- memcache/memcache_test.go | 123 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 137 insertions(+), 1 deletion(-) diff --git a/memcache/memcache.go b/memcache/memcache.go index 5d71b827..0b855dc3 100644 --- a/memcache/memcache.go +++ b/memcache/memcache.go @@ -136,6 +136,11 @@ type Client struct { // If zero, DefaultTimeout is used. Timeout time.Duration + // DialTimeout specifies the timeout for establishing new connections, + // including the TLS handshake if TLS is enabled. + // If zero, Timeout is used (for backward compatibility). + DialTimeout time.Duration + // MaxIdleConns specifies the maximum number of idle connections that will // be maintained per address. If less than one, DefaultMaxIdleConns will be // used. @@ -237,6 +242,14 @@ func (c *Client) netTimeout() time.Duration { return DefaultTimeout } +func (c *Client) dialTimeout() time.Duration { + if c.DialTimeout != 0 { + return c.DialTimeout + } + // Fall back to Timeout for backward compatibility + return c.netTimeout() +} + func (c *Client) maxIdleConns() int { if c.MaxIdleConns > 0 { return c.MaxIdleConns @@ -265,7 +278,7 @@ func (c *Client) dial(addr net.Addr) (net.Conn, error) { nc net.Conn err error ) - nd := net.Dialer{Timeout: c.netTimeout()} + nd := net.Dialer{Timeout: c.dialTimeout()} if c.TlsConfig != nil { nc, err = tls.DialWithDialer(&nd, addr.Network(), addr.String(), c.TlsConfig) } else { diff --git a/memcache/memcache_test.go b/memcache/memcache_test.go index c448d6db..d8b48f01 100644 --- a/memcache/memcache_test.go +++ b/memcache/memcache_test.go @@ -346,3 +346,126 @@ func TestLocalhostTLS(t *testing.T) { c.TlsConfig = tlsConfig testWithClient(t, c) } + +// TestDialTimeoutMethod tests the dialTimeout() method logic +func TestDialTimeoutMethod(t *testing.T) { + tests := []struct { + name string + dialTimeout time.Duration + timeout time.Duration + expected time.Duration + }{ + { + name: "DialTimeout set, should use DialTimeout", + dialTimeout: 500 * time.Millisecond, + timeout: 100 * time.Millisecond, + expected: 500 * time.Millisecond, + }, + { + name: "DialTimeout zero, should fall back to Timeout", + dialTimeout: 0, + timeout: 200 * time.Millisecond, + expected: 200 * time.Millisecond, + }, + { + name: "Both zero, should use DefaultTimeout", + dialTimeout: 0, + timeout: 0, + expected: DefaultTimeout, + }, + { + name: "Only DialTimeout set", + dialTimeout: 300 * time.Millisecond, + timeout: 0, + expected: 300 * time.Millisecond, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{ + DialTimeout: tt.dialTimeout, + Timeout: tt.timeout, + } + got := c.dialTimeout() + if got != tt.expected { + t.Errorf("dialTimeout() = %v, want %v", got, tt.expected) + } + }) + } +} + +// TestNetTimeoutMethod tests the netTimeout() method remains unchanged +func TestNetTimeoutMethod(t *testing.T) { + tests := []struct { + name string + timeout time.Duration + expected time.Duration + }{ + { + name: "Timeout set", + timeout: 50 * time.Millisecond, + expected: 50 * time.Millisecond, + }, + { + name: "Timeout zero, should use DefaultTimeout", + timeout: 0, + expected: DefaultTimeout, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{ + Timeout: tt.timeout, + } + got := c.netTimeout() + if got != tt.expected { + t.Errorf("netTimeout() = %v, want %v", got, tt.expected) + } + }) + } +} + +// TestBackwardCompatibility verifies old behavior when DialTimeout is not set +func TestBackwardCompatibility(t *testing.T) { + c := &Client{ + Timeout: 250 * time.Millisecond, + // DialTimeout not set (zero value) + } + + // dialTimeout() should fall back to netTimeout() which returns Timeout + dialTimeout := c.dialTimeout() + netTimeout := c.netTimeout() + + if dialTimeout != netTimeout { + t.Errorf("Backward compatibility broken: dialTimeout=%v, netTimeout=%v", dialTimeout, netTimeout) + } + + if dialTimeout != 250*time.Millisecond { + t.Errorf("Expected dialTimeout to be 250ms (from Timeout), got %v", dialTimeout) + } +} + +// TestSeparateTimeouts verifies DialTimeout and Timeout can be set independently +func TestSeparateTimeouts(t *testing.T) { + c := &Client{ + DialTimeout: 500 * time.Millisecond, + Timeout: 50 * time.Millisecond, + } + + dialTimeout := c.dialTimeout() + netTimeout := c.netTimeout() + + if dialTimeout != 500*time.Millisecond { + t.Errorf("Expected dialTimeout to be 500ms, got %v", dialTimeout) + } + + if netTimeout != 50*time.Millisecond { + t.Errorf("Expected netTimeout to be 50ms, got %v", netTimeout) + } + + if dialTimeout == netTimeout { + t.Error("DialTimeout and Timeout should be independent") + } +}