diff --git a/konnectivity-https-proxy/cmd.go b/konnectivity-https-proxy/cmd.go index dbdb6248529..cb0f1633207 100644 --- a/konnectivity-https-proxy/cmd.go +++ b/konnectivity-https-proxy/cmd.go @@ -1,6 +1,7 @@ package konnectivityhttpsproxy import ( + "context" "crypto/tls" "encoding/base64" "fmt" @@ -151,7 +152,6 @@ func NewStartCommand() *cobra.Command { l.V(4).Info("Should proxy", "url", u) return u, nil }, - Dial: konnectivityDialer.Dial, DialContext: konnectivityDialer.DialContext, } if httpsProxyURL != "" { @@ -175,12 +175,13 @@ func NewStartCommand() *cobra.Command { } type dialFunc func(network, addr string) (net.Conn, error) +type dialContextFunc func(ctx context.Context, network, addr string) (net.Conn, error) type dialRequestFunc func(req *http.Request, network, addr string) (net.Conn, error) -func dialDirectFunc(httpProxy *goproxy.ProxyHttpServer) dialFunc { - // NOTE: the function signature is determined by the goproxy library, it requires the deprecated version - // nolint:staticcheck - return httpProxy.Tr.Dial +func dialDirectFunc(httpProxy *goproxy.ProxyHttpServer) dialContextFunc { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + return httpProxy.Tr.DialContext(ctx, network, addr) + } } func dialThroughProxyFunc(httpProxy *goproxy.ProxyHttpServer, proxyURL string, proxyURLUser *url.Userinfo) dialFunc { @@ -219,14 +220,14 @@ func addBasicAuthHeader(proxyUser *url.Userinfo) func(req *http.Request) { } } -func connectDialFunc(shouldDialDirect func(*url.URL) (bool, error), dialDirectly dialFunc, dialThroughProxy dialFunc) dialRequestFunc { +func connectDialFunc(shouldDialDirect func(*url.URL) (bool, error), dialDirectly dialContextFunc, dialThroughProxy dialFunc) dialRequestFunc { return func(req *http.Request, network, addr string) (net.Conn, error) { shouldDialDirectly, err := shouldDialDirect(req.URL) if err != nil { return nil, err } if shouldDialDirectly { - return dialDirectly(network, addr) + return dialDirectly(req.Context(), network, addr) } return dialThroughProxy(network, addr) } diff --git a/konnectivity-https-proxy/cmd_test.go b/konnectivity-https-proxy/cmd_test.go index 9f1c8f6c696..103affc8a01 100644 --- a/konnectivity-https-proxy/cmd_test.go +++ b/konnectivity-https-proxy/cmd_test.go @@ -1,8 +1,11 @@ package konnectivityhttpsproxy import ( + "context" "encoding/base64" + "errors" "fmt" + "net" "net/http" "net/url" "strings" @@ -10,6 +13,7 @@ import ( . "github.com/onsi/gomega" + "github.com/elazarl/goproxy" "golang.org/x/net/http/httpproxy" ) @@ -101,6 +105,133 @@ func TestShouldDialDirectFunc(t *testing.T) { } } +func TestDialDirectFunc(t *testing.T) { + dialErr := errors.New("dial failed") + + tests := []struct { + name string + dialCtx func(ctx context.Context, network, addr string) (net.Conn, error) + addr func(t *testing.T) string + expectErr error + }{ + { + name: "When dialing with a valid listener it should connect successfully", + dialCtx: (&net.Dialer{}).DialContext, + addr: func(t *testing.T) string { + listener, err := (&net.ListenConfig{}).Listen(t.Context(), "tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to create listener: %v", err) + } + t.Cleanup(func() { listener.Close() }) + return listener.Addr().String() + }, + }, + { + name: "When the transport DialContext fails it should return an error", + dialCtx: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, dialErr + }, + addr: func(t *testing.T) string { return "127.0.0.1:1" }, + expectErr: dialErr, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + + httpProxy := goproxy.NewProxyHttpServer() + httpProxy.Tr = &http.Transport{ + DialContext: tc.dialCtx, + } + + dialFn := dialDirectFunc(httpProxy) + conn, err := dialFn(t.Context(), "tcp", tc.addr(t)) + + if tc.expectErr != nil { + g.Expect(err).To(MatchError(tc.expectErr)) + g.Expect(conn).To(BeNil()) + } else { + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(conn).NotTo(BeNil()) + conn.Close() + } + }) + } +} + +func TestConnectDialFunc(t *testing.T) { + lookupErr := errors.New("lookup failed") + + tests := []struct { + name string + shouldDialDirect bool + shouldDialDirectErr error + expectDialDirect bool + expectDialProxy bool + expectErr error + }{ + { + name: "When shouldDialDirect returns true it should dial directly with request context", + shouldDialDirect: true, + expectDialDirect: true, + }, + { + name: "When shouldDialDirect returns false it should dial through proxy", + expectDialProxy: true, + }, + { + name: "When shouldDialDirect returns an error it should propagate the error", + shouldDialDirectErr: lookupErr, + expectErr: lookupErr, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := NewGomegaWithT(t) + + type contextKey string + reqCtx := context.WithValue(t.Context(), contextKey("test"), "value") + req, err := http.NewRequestWithContext(reqCtx, http.MethodConnect, "https://example.com:443", nil) + g.Expect(err).NotTo(HaveOccurred()) + + directCalled := false + proxyCalled := false + var capturedCtx any + + dialDirectly := func(ctx context.Context, network, addr string) (net.Conn, error) { + directCalled = true + capturedCtx = ctx + return nil, nil + } + dialThroughProxy := func(network, addr string) (net.Conn, error) { + proxyCalled = true + return nil, nil + } + shouldDialDirect := func(u *url.URL) (bool, error) { + return tc.shouldDialDirect, tc.shouldDialDirectErr + } + + f := connectDialFunc(shouldDialDirect, dialDirectly, dialThroughProxy) + conn, err := f(req, "tcp", "example.com:443") + + if tc.expectErr != nil { + g.Expect(err).To(MatchError(tc.expectErr)) + g.Expect(conn).To(BeNil()) + } else { + g.Expect(err).NotTo(HaveOccurred()) + } + g.Expect(directCalled).To(Equal(tc.expectDialDirect)) + g.Expect(proxyCalled).To(Equal(tc.expectDialProxy)) + if tc.expectDialDirect { + g.Expect(capturedCtx).To(Equal(reqCtx)) + g.Expect(capturedCtx.(context.Context).Value(contextKey("test"))).To(Equal("value")) + } + }) + } +} + func TestAddBasicAuthHeader(t *testing.T) { userInfo := url.UserPassword("user", "password") tests := []struct {