Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions konnectivity-https-proxy/cmd.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package konnectivityhttpsproxy

import (
"context"
"crypto/tls"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checked: removing Dial is safe — Go's http.Transport prefers DialContext when both are set, and goproxy's internal dial() also checks Tr.DialContext first. Good cleanup.

"encoding/base64"
"fmt"
Expand Down Expand Up @@ -151,7 +152,6 @@ func NewStartCommand() *cobra.Command {
l.V(4).Info("Should proxy", "url", u)
return u, nil
},
Dial: konnectivityDialer.Dial,
DialContext: konnectivityDialer.DialContext,
}
if httpsProxyURL != "" {
Expand All @@ -175,12 +175,13 @@ func NewStartCommand() *cobra.Command {
}

type dialFunc func(network, addr string) (net.Conn, error)
type dialContextFunc func(ctx context.Context, network, addr string) (net.Conn, error)
type dialRequestFunc func(req *http.Request, network, addr string) (net.Conn, error)

func dialDirectFunc(httpProxy *goproxy.ProxyHttpServer) dialFunc {
// NOTE: the function signature is determined by the goproxy library, it requires the deprecated version
// nolint:staticcheck
return httpProxy.Tr.Dial
func dialDirectFunc(httpProxy *goproxy.ProxyHttpServer) dialContextFunc {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
return httpProxy.Tr.DialContext(ctx, network, addr)
}
}

func dialThroughProxyFunc(httpProxy *goproxy.ProxyHttpServer, proxyURL string, proxyURLUser *url.Userinfo) dialFunc {
Expand Down Expand Up @@ -219,14 +220,14 @@ func addBasicAuthHeader(proxyUser *url.Userinfo) func(req *http.Request) {
}
}

func connectDialFunc(shouldDialDirect func(*url.URL) (bool, error), dialDirectly dialFunc, dialThroughProxy dialFunc) dialRequestFunc {
func connectDialFunc(shouldDialDirect func(*url.URL) (bool, error), dialDirectly dialContextFunc, dialThroughProxy dialFunc) dialRequestFunc {
return func(req *http.Request, network, addr string) (net.Conn, error) {
shouldDialDirectly, err := shouldDialDirect(req.URL)
if err != nil {
return nil, err
}
if shouldDialDirectly {
return dialDirectly(network, addr)
return dialDirectly(req.Context(), network, addr)
}
return dialThroughProxy(network, addr)
}
Expand Down
131 changes: 131 additions & 0 deletions konnectivity-https-proxy/cmd_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
package konnectivityhttpsproxy

import (
"context"
"encoding/base64"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"testing"

. "github.com/onsi/gomega"

"github.com/elazarl/goproxy"
"golang.org/x/net/http/httpproxy"
)

Expand Down Expand Up @@ -101,6 +105,133 @@ func TestShouldDialDirectFunc(t *testing.T) {
}
}

func TestDialDirectFunc(t *testing.T) {
dialErr := errors.New("dial failed")

tests := []struct {
name string
dialCtx func(ctx context.Context, network, addr string) (net.Conn, error)
addr func(t *testing.T) string
expectErr error
}{
{
name: "When dialing with a valid listener it should connect successfully",
dialCtx: (&net.Dialer{}).DialContext,
addr: func(t *testing.T) string {
listener, err := (&net.ListenConfig{}).Listen(t.Context(), "tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to create listener: %v", err)
}
t.Cleanup(func() { listener.Close() })
return listener.Addr().String()
},
},
{
name: "When the transport DialContext fails it should return an error",
dialCtx: func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, dialErr
},
addr: func(t *testing.T) string { return "127.0.0.1:1" },
expectErr: dialErr,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
g := NewGomegaWithT(t)

httpProxy := goproxy.NewProxyHttpServer()
httpProxy.Tr = &http.Transport{
DialContext: tc.dialCtx,
}

dialFn := dialDirectFunc(httpProxy)
conn, err := dialFn(t.Context(), "tcp", tc.addr(t))

if tc.expectErr != nil {
g.Expect(err).To(MatchError(tc.expectErr))
g.Expect(conn).To(BeNil())
} else {
g.Expect(err).NotTo(HaveOccurred())
g.Expect(conn).NotTo(BeNil())
conn.Close()
}
})
}
}

func TestConnectDialFunc(t *testing.T) {
lookupErr := errors.New("lookup failed")

tests := []struct {
name string
shouldDialDirect bool
shouldDialDirectErr error
expectDialDirect bool
expectDialProxy bool
expectErr error
}{
{
name: "When shouldDialDirect returns true it should dial directly with request context",
shouldDialDirect: true,
expectDialDirect: true,
},
{
name: "When shouldDialDirect returns false it should dial through proxy",
expectDialProxy: true,
},
{
name: "When shouldDialDirect returns an error it should propagate the error",
shouldDialDirectErr: lookupErr,
expectErr: lookupErr,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
g := NewGomegaWithT(t)

type contextKey string
reqCtx := context.WithValue(t.Context(), contextKey("test"), "value")
req, err := http.NewRequestWithContext(reqCtx, http.MethodConnect, "https://example.com:443", nil)
g.Expect(err).NotTo(HaveOccurred())

directCalled := false
proxyCalled := false
var capturedCtx any

dialDirectly := func(ctx context.Context, network, addr string) (net.Conn, error) {
directCalled = true
capturedCtx = ctx
return nil, nil
}
dialThroughProxy := func(network, addr string) (net.Conn, error) {
proxyCalled = true
return nil, nil
}
shouldDialDirect := func(u *url.URL) (bool, error) {
return tc.shouldDialDirect, tc.shouldDialDirectErr
}

f := connectDialFunc(shouldDialDirect, dialDirectly, dialThroughProxy)
conn, err := f(req, "tcp", "example.com:443")

if tc.expectErr != nil {
g.Expect(err).To(MatchError(tc.expectErr))
g.Expect(conn).To(BeNil())
} else {
g.Expect(err).NotTo(HaveOccurred())
}
g.Expect(directCalled).To(Equal(tc.expectDialDirect))
g.Expect(proxyCalled).To(Equal(tc.expectDialProxy))
if tc.expectDialDirect {
g.Expect(capturedCtx).To(Equal(reqCtx))
g.Expect(capturedCtx.(context.Context).Value(contextKey("test"))).To(Equal("value"))
}
})
}
}

func TestAddBasicAuthHeader(t *testing.T) {
userInfo := url.UserPassword("user", "password")
tests := []struct {
Expand Down