From 47a4b0b6e863aa33c3dbd6e4bf81fbcb4e6d5c67 Mon Sep 17 00:00:00 2001 From: Poornima Singour Date: Mon, 25 May 2026 16:22:42 +0530 Subject: [PATCH 1/3] fix(konnectivity): replace deprecated Dial with DialContext Replace the deprecated `net.Dial` with `DialContext` in konnectivity-https-proxy and propagate the original request context through the dial chain instead of using context.Background(). Introduce dialContextFunc type to carry context from connectDialFunc (which has access to req.Context()) through to dialDirectFunc. Add unit tests for dialDirectFunc covering success and failure scenarios. Co-Authored-By: Poornima Singour --- konnectivity-https-proxy/cmd.go | 15 ++-- konnectivity-https-proxy/cmd_test.go | 114 +++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 7 deletions(-) diff --git a/konnectivity-https-proxy/cmd.go b/konnectivity-https-proxy/cmd.go index dbdb6248529..cb0f1633207 100644 --- a/konnectivity-https-proxy/cmd.go +++ b/konnectivity-https-proxy/cmd.go @@ -1,6 +1,7 @@ package konnectivityhttpsproxy import ( + "context" "crypto/tls" "encoding/base64" "fmt" @@ -151,7 +152,6 @@ func NewStartCommand() *cobra.Command { l.V(4).Info("Should proxy", "url", u) return u, nil }, - Dial: konnectivityDialer.Dial, DialContext: konnectivityDialer.DialContext, } if httpsProxyURL != "" { @@ -175,12 +175,13 @@ func NewStartCommand() *cobra.Command { } type dialFunc func(network, addr string) (net.Conn, error) +type dialContextFunc func(ctx context.Context, network, addr string) (net.Conn, error) type dialRequestFunc func(req *http.Request, network, addr string) (net.Conn, error) -func dialDirectFunc(httpProxy *goproxy.ProxyHttpServer) dialFunc { - // NOTE: the function signature is determined by the goproxy library, it requires the deprecated version - // nolint:staticcheck - return httpProxy.Tr.Dial +func dialDirectFunc(httpProxy *goproxy.ProxyHttpServer) dialContextFunc { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + return httpProxy.Tr.DialContext(ctx, network, addr) + } } func dialThroughProxyFunc(httpProxy *goproxy.ProxyHttpServer, proxyURL string, proxyURLUser *url.Userinfo) dialFunc { @@ -219,14 +220,14 @@ func addBasicAuthHeader(proxyUser *url.Userinfo) func(req *http.Request) { } } -func connectDialFunc(shouldDialDirect func(*url.URL) (bool, error), dialDirectly dialFunc, dialThroughProxy dialFunc) dialRequestFunc { +func connectDialFunc(shouldDialDirect func(*url.URL) (bool, error), dialDirectly dialContextFunc, dialThroughProxy dialFunc) dialRequestFunc { return func(req *http.Request, network, addr string) (net.Conn, error) { shouldDialDirectly, err := shouldDialDirect(req.URL) if err != nil { return nil, err } if shouldDialDirectly { - return dialDirectly(network, addr) + return dialDirectly(req.Context(), network, addr) } return dialThroughProxy(network, addr) } diff --git a/konnectivity-https-proxy/cmd_test.go b/konnectivity-https-proxy/cmd_test.go index 9f1c8f6c696..23973463f71 100644 --- a/konnectivity-https-proxy/cmd_test.go +++ b/konnectivity-https-proxy/cmd_test.go @@ -1,8 +1,11 @@ package konnectivityhttpsproxy import ( + "context" "encoding/base64" + "errors" "fmt" + "net" "net/http" "net/url" "strings" @@ -10,6 +13,7 @@ import ( . "github.com/onsi/gomega" + "github.com/elazarl/goproxy" "golang.org/x/net/http/httpproxy" ) @@ -101,6 +105,116 @@ func TestShouldDialDirectFunc(t *testing.T) { } } +func TestDialDirectFunc(t *testing.T) { + t.Run("When dialing with a valid listener it should connect successfully", func(t *testing.T) { + g := NewGomegaWithT(t) + + listener, err := (&net.ListenConfig{}).Listen(t.Context(), "tcp", "127.0.0.1:0") + g.Expect(err).NotTo(HaveOccurred()) + defer listener.Close() + + httpProxy := goproxy.NewProxyHttpServer() + httpProxy.Tr = &http.Transport{ + DialContext: (&net.Dialer{}).DialContext, + } + + dialFn := dialDirectFunc(httpProxy) + conn, err := dialFn(t.Context(), "tcp", listener.Addr().String()) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(conn).NotTo(BeNil()) + conn.Close() + }) + + t.Run("When the transport DialContext fails it should return an error", func(t *testing.T) { + g := NewGomegaWithT(t) + + expectedErr := errors.New("dial failed") + httpProxy := goproxy.NewProxyHttpServer() + httpProxy.Tr = &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, expectedErr + }, + } + + dialFn := dialDirectFunc(httpProxy) + conn, err := dialFn(t.Context(), "tcp", "127.0.0.1:1") + g.Expect(err).To(MatchError(expectedErr)) + g.Expect(conn).To(BeNil()) + }) +} + +func TestConnectDialFunc(t *testing.T) { + lookupErr := errors.New("lookup failed") + + tests := []struct { + name string + shouldDialDirect bool + shouldDialDirectErr error + expectDialDirect bool + expectDialProxy bool + expectErr error + }{ + { + name: "When shouldDialDirect returns true it should dial directly with request context", + shouldDialDirect: true, + expectDialDirect: true, + }, + { + name: "When shouldDialDirect returns false it should dial through proxy", + expectDialProxy: true, + }, + { + name: "When shouldDialDirect returns an error it should propagate the error", + shouldDialDirectErr: lookupErr, + expectErr: lookupErr, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + + type contextKey string + reqCtx := context.WithValue(t.Context(), contextKey("test"), "value") + req, err := http.NewRequestWithContext(reqCtx, "CONNECT", "https://example.com:443", nil) + g.Expect(err).NotTo(HaveOccurred()) + + directCalled := false + proxyCalled := false + var capturedCtx context.Context + + dialDirectly := func(ctx context.Context, network, addr string) (net.Conn, error) { + directCalled = true + capturedCtx = ctx + return nil, nil + } + dialThroughProxy := func(network, addr string) (net.Conn, error) { + proxyCalled = true + return nil, nil + } + shouldDialDirect := func(u *url.URL) (bool, error) { + return tc.shouldDialDirect, tc.shouldDialDirectErr + } + + f := connectDialFunc(shouldDialDirect, dialDirectly, dialThroughProxy) + conn, err := f(req, "tcp", "example.com:443") + + if tc.expectErr != nil { + g.Expect(err).To(MatchError(tc.expectErr)) + g.Expect(conn).To(BeNil()) + } else { + g.Expect(err).NotTo(HaveOccurred()) + } + g.Expect(directCalled).To(Equal(tc.expectDialDirect)) + g.Expect(proxyCalled).To(Equal(tc.expectDialProxy)) + if tc.expectDialDirect { + g.Expect(capturedCtx).To(Equal(reqCtx)) + g.Expect(capturedCtx.Value(contextKey("test"))).To(Equal("value")) + } + }) + } +} + func TestAddBasicAuthHeader(t *testing.T) { userInfo := url.UserPassword("user", "password") tests := []struct { From d716f0b52e50a5bf377b8b7c68c1692f4cd5b420 Mon Sep 17 00:00:00 2001 From: Poornima Singour Date: Wed, 3 Jun 2026 20:59:18 +0530 Subject: [PATCH 2/3] fix(konnectivity): fix gci lint alignment in cmd_test.go Co-Authored-By: Claude Opus 4.6 (1M context) --- konnectivity-https-proxy/cmd_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/konnectivity-https-proxy/cmd_test.go b/konnectivity-https-proxy/cmd_test.go index 23973463f71..adbc1b93c4f 100644 --- a/konnectivity-https-proxy/cmd_test.go +++ b/konnectivity-https-proxy/cmd_test.go @@ -147,12 +147,12 @@ func TestConnectDialFunc(t *testing.T) { lookupErr := errors.New("lookup failed") tests := []struct { - name string - shouldDialDirect bool + name string + shouldDialDirect bool shouldDialDirectErr error - expectDialDirect bool - expectDialProxy bool - expectErr error + expectDialDirect bool + expectDialProxy bool + expectErr error }{ { name: "When shouldDialDirect returns true it should dial directly with request context", From 3cd83633980fe8e7e6b67fab2e858762df687c30 Mon Sep 17 00:00:00 2001 From: Poornima Singour Date: Thu, 4 Jun 2026 15:53:27 +0530 Subject: [PATCH 3/3] refactor(konnectivity): convert TestDialDirectFunc to table-driven test Co-Authored-By: Claude Opus 4.6 (1M context) --- konnectivity-https-proxy/cmd_test.go | 85 +++++++++++++++++----------- 1 file changed, 51 insertions(+), 34 deletions(-) diff --git a/konnectivity-https-proxy/cmd_test.go b/konnectivity-https-proxy/cmd_test.go index adbc1b93c4f..103affc8a01 100644 --- a/konnectivity-https-proxy/cmd_test.go +++ b/konnectivity-https-proxy/cmd_test.go @@ -106,41 +106,58 @@ func TestShouldDialDirectFunc(t *testing.T) { } func TestDialDirectFunc(t *testing.T) { - t.Run("When dialing with a valid listener it should connect successfully", func(t *testing.T) { - g := NewGomegaWithT(t) + dialErr := errors.New("dial failed") - listener, err := (&net.ListenConfig{}).Listen(t.Context(), "tcp", "127.0.0.1:0") - g.Expect(err).NotTo(HaveOccurred()) - defer listener.Close() + tests := []struct { + name string + dialCtx func(ctx context.Context, network, addr string) (net.Conn, error) + addr func(t *testing.T) string + expectErr error + }{ + { + name: "When dialing with a valid listener it should connect successfully", + dialCtx: (&net.Dialer{}).DialContext, + addr: func(t *testing.T) string { + listener, err := (&net.ListenConfig{}).Listen(t.Context(), "tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to create listener: %v", err) + } + t.Cleanup(func() { listener.Close() }) + return listener.Addr().String() + }, + }, + { + name: "When the transport DialContext fails it should return an error", + dialCtx: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, dialErr + }, + addr: func(t *testing.T) string { return "127.0.0.1:1" }, + expectErr: dialErr, + }, + } - httpProxy := goproxy.NewProxyHttpServer() - httpProxy.Tr = &http.Transport{ - DialContext: (&net.Dialer{}).DialContext, - } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) - dialFn := dialDirectFunc(httpProxy) - conn, err := dialFn(t.Context(), "tcp", listener.Addr().String()) - g.Expect(err).NotTo(HaveOccurred()) - g.Expect(conn).NotTo(BeNil()) - conn.Close() - }) - - t.Run("When the transport DialContext fails it should return an error", func(t *testing.T) { - g := NewGomegaWithT(t) - - expectedErr := errors.New("dial failed") - httpProxy := goproxy.NewProxyHttpServer() - httpProxy.Tr = &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return nil, expectedErr - }, - } + httpProxy := goproxy.NewProxyHttpServer() + httpProxy.Tr = &http.Transport{ + DialContext: tc.dialCtx, + } - dialFn := dialDirectFunc(httpProxy) - conn, err := dialFn(t.Context(), "tcp", "127.0.0.1:1") - g.Expect(err).To(MatchError(expectedErr)) - g.Expect(conn).To(BeNil()) - }) + dialFn := dialDirectFunc(httpProxy) + conn, err := dialFn(t.Context(), "tcp", tc.addr(t)) + + if tc.expectErr != nil { + g.Expect(err).To(MatchError(tc.expectErr)) + g.Expect(conn).To(BeNil()) + } else { + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(conn).NotTo(BeNil()) + conn.Close() + } + }) + } } func TestConnectDialFunc(t *testing.T) { @@ -176,12 +193,12 @@ func TestConnectDialFunc(t *testing.T) { type contextKey string reqCtx := context.WithValue(t.Context(), contextKey("test"), "value") - req, err := http.NewRequestWithContext(reqCtx, "CONNECT", "https://example.com:443", nil) + req, err := http.NewRequestWithContext(reqCtx, http.MethodConnect, "https://example.com:443", nil) g.Expect(err).NotTo(HaveOccurred()) directCalled := false proxyCalled := false - var capturedCtx context.Context + var capturedCtx any dialDirectly := func(ctx context.Context, network, addr string) (net.Conn, error) { directCalled = true @@ -209,7 +226,7 @@ func TestConnectDialFunc(t *testing.T) { g.Expect(proxyCalled).To(Equal(tc.expectDialProxy)) if tc.expectDialDirect { g.Expect(capturedCtx).To(Equal(reqCtx)) - g.Expect(capturedCtx.Value(contextKey("test"))).To(Equal("value")) + g.Expect(capturedCtx.(context.Context).Value(contextKey("test"))).To(Equal("value")) } }) }