Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion memcache/memcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ type Client struct {
// If zero, DefaultTimeout is used.
Timeout time.Duration

// DialTimeout specifies the timeout for establishing new connections,
// including the TLS handshake if TLS is enabled.
// If zero, Timeout is used (for backward compatibility).
DialTimeout time.Duration

// MaxIdleConns specifies the maximum number of idle connections that will
// be maintained per address. If less than one, DefaultMaxIdleConns will be
// used.
Expand Down Expand Up @@ -237,6 +242,14 @@ func (c *Client) netTimeout() time.Duration {
return DefaultTimeout
}

func (c *Client) dialTimeout() time.Duration {
if c.DialTimeout != 0 {
return c.DialTimeout
}
// Fall back to Timeout for backward compatibility
return c.netTimeout()
}

func (c *Client) maxIdleConns() int {
if c.MaxIdleConns > 0 {
return c.MaxIdleConns
Expand Down Expand Up @@ -265,7 +278,7 @@ func (c *Client) dial(addr net.Addr) (net.Conn, error) {
nc net.Conn
err error
)
nd := net.Dialer{Timeout: c.netTimeout()}
nd := net.Dialer{Timeout: c.dialTimeout()}
if c.TlsConfig != nil {
nc, err = tls.DialWithDialer(&nd, addr.Network(), addr.String(), c.TlsConfig)
} else {
Expand Down
123 changes: 123 additions & 0 deletions memcache/memcache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,126 @@ func TestLocalhostTLS(t *testing.T) {
c.TlsConfig = tlsConfig
testWithClient(t, c)
}

// TestDialTimeoutMethod tests the dialTimeout() method logic
func TestDialTimeoutMethod(t *testing.T) {
tests := []struct {
name string
dialTimeout time.Duration
timeout time.Duration
expected time.Duration
}{
{
name: "DialTimeout set, should use DialTimeout",
dialTimeout: 500 * time.Millisecond,
timeout: 100 * time.Millisecond,
expected: 500 * time.Millisecond,
},
{
name: "DialTimeout zero, should fall back to Timeout",
dialTimeout: 0,
timeout: 200 * time.Millisecond,
expected: 200 * time.Millisecond,
},
{
name: "Both zero, should use DefaultTimeout",
dialTimeout: 0,
timeout: 0,
expected: DefaultTimeout,
},
{
name: "Only DialTimeout set",
dialTimeout: 300 * time.Millisecond,
timeout: 0,
expected: 300 * time.Millisecond,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Client{
DialTimeout: tt.dialTimeout,
Timeout: tt.timeout,
}
got := c.dialTimeout()
if got != tt.expected {
t.Errorf("dialTimeout() = %v, want %v", got, tt.expected)
}
})
}
}

// TestNetTimeoutMethod tests the netTimeout() method remains unchanged
func TestNetTimeoutMethod(t *testing.T) {
tests := []struct {
name string
timeout time.Duration
expected time.Duration
}{
{
name: "Timeout set",
timeout: 50 * time.Millisecond,
expected: 50 * time.Millisecond,
},
{
name: "Timeout zero, should use DefaultTimeout",
timeout: 0,
expected: DefaultTimeout,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Client{
Timeout: tt.timeout,
}
got := c.netTimeout()
if got != tt.expected {
t.Errorf("netTimeout() = %v, want %v", got, tt.expected)
}
})
}
}

// TestBackwardCompatibility verifies old behavior when DialTimeout is not set
func TestBackwardCompatibility(t *testing.T) {
c := &Client{
Timeout: 250 * time.Millisecond,
// DialTimeout not set (zero value)
}

// dialTimeout() should fall back to netTimeout() which returns Timeout
dialTimeout := c.dialTimeout()
netTimeout := c.netTimeout()

if dialTimeout != netTimeout {
t.Errorf("Backward compatibility broken: dialTimeout=%v, netTimeout=%v", dialTimeout, netTimeout)
}

if dialTimeout != 250*time.Millisecond {
t.Errorf("Expected dialTimeout to be 250ms (from Timeout), got %v", dialTimeout)
}
}

// TestSeparateTimeouts verifies DialTimeout and Timeout can be set independently
func TestSeparateTimeouts(t *testing.T) {
c := &Client{
DialTimeout: 500 * time.Millisecond,
Timeout: 50 * time.Millisecond,
}

dialTimeout := c.dialTimeout()
netTimeout := c.netTimeout()

if dialTimeout != 500*time.Millisecond {
t.Errorf("Expected dialTimeout to be 500ms, got %v", dialTimeout)
}

if netTimeout != 50*time.Millisecond {
t.Errorf("Expected netTimeout to be 50ms, got %v", netTimeout)
}

if dialTimeout == netTimeout {
t.Error("DialTimeout and Timeout should be independent")
}
}