diff --git a/internal/credentials/gcloud.go b/internal/credentials/gcloud.go index 79913c1..e779096 100644 --- a/internal/credentials/gcloud.go +++ b/internal/credentials/gcloud.go @@ -73,6 +73,11 @@ func (g *GCloudInjector) init() error { // Inject sets the Authorization: Bearer header with a fresh OAuth2 token. // Always overrides — the agent may have a token from a dummy ADC file. func (g *GCloudInjector) Inject(req *http.Request) bool { + if req == nil { + log.Printf("DEFENSIVE_CHECK: GCloudInjector.Inject called with nil request") + return false + } + if err := g.init(); err != nil { log.Printf("ERROR gcloud credential init failed: %v", err) return false diff --git a/internal/credentials/gcloud_test.go b/internal/credentials/gcloud_test.go new file mode 100644 index 0000000..ad13c87 --- /dev/null +++ b/internal/credentials/gcloud_test.go @@ -0,0 +1,67 @@ +package credentials + +import ( + "net/http" + "testing" +) + +func TestGCloudInjector_NilRequest(t *testing.T) { + // Create injector with non-existent path (will fail init, but that's OK for this test) + inj := NewGCloudInjector("/nonexistent/path/to/adc.json") + + // Should handle nil request gracefully + if inj.Inject(nil) { + t.Error("nil request should return false") + } +} + +func TestGCloudInjector_NilHeader(t *testing.T) { + inj := NewGCloudInjector("/nonexistent/path/to/adc.json") + + req := &http.Request{ + Header: nil, + } + + if inj.Inject(req) { + t.Error("request with nil Header should return false") + } +} + +func TestGCloudInjector_InitFailure(t *testing.T) { + inj := NewGCloudInjector("/nonexistent/path/to/adc.json") + + req := &http.Request{ + Header: make(http.Header), + } + + // Should return false due to init failure (file doesn't exist) + if inj.Inject(req) { + t.Error("inject should fail when ADC file doesn't exist") + } + + // Authorization header should not be set + if got := req.Header.Get("Authorization"); got != "" { + t.Errorf("Authorization should be empty on init failure, got %q", got) + } +} + +func TestGCloudInjectorFromJSON_NilRequest(t *testing.T) { + // Invalid JSON will cause init to fail, but nil check comes first + inj := NewGCloudInjectorFromJSON([]byte("invalid json")) + + if inj.Inject(nil) { + t.Error("nil request should return false") + } +} + +func TestGCloudInjectorFromJSON_NilHeader(t *testing.T) { + inj := NewGCloudInjectorFromJSON([]byte("invalid json")) + + req := &http.Request{ + Header: nil, + } + + if inj.Inject(req) { + t.Error("request with nil Header should return false") + } +} diff --git a/internal/credentials/static.go b/internal/credentials/static.go index 493f673..341d89c 100644 --- a/internal/credentials/static.go +++ b/internal/credentials/static.go @@ -1,9 +1,20 @@ package credentials import ( + "log" "net/http" ) +// validateRequest checks if request is valid for credential injection. +// Returns false if req or req.Header is nil, logging the injector name for debugging. +func validateRequest(req *http.Request, injectorName string) bool { + if req == nil || req.Header == nil { + log.Printf("DEFENSIVE_CHECK: %s.Inject called with nil request or Header", injectorName) + return false + } + return true +} + // HeaderInjector injects a static value into a specific header. // Always overrides any existing value — the agent should never // control which credentials are used. @@ -15,6 +26,9 @@ type HeaderInjector struct { } func (h *HeaderInjector) Inject(req *http.Request) bool { + if !validateRequest(req, "HeaderInjector") { + return false + } req.Header.Set(h.Header, h.Value) return true } @@ -26,6 +40,9 @@ type BearerInjector struct { } func (b *BearerInjector) Inject(req *http.Request) bool { + if !validateRequest(req, "BearerInjector") { + return false + } req.Header.Set("Authorization", "Bearer "+b.Token) return true } @@ -38,6 +55,9 @@ type APIKeyInjector struct { } func (a *APIKeyInjector) Inject(req *http.Request) bool { + if !validateRequest(req, "APIKeyInjector") { + return false + } req.Header.Set(a.HeaderName, a.Key) return true } diff --git a/internal/credentials/static_test.go b/internal/credentials/static_test.go new file mode 100644 index 0000000..0fb841c --- /dev/null +++ b/internal/credentials/static_test.go @@ -0,0 +1,67 @@ +package credentials + +import ( + "net/http" + "testing" +) + +func TestHeaderInjector_NilRequest(t *testing.T) { + inj := &HeaderInjector{Header: "X-Custom", Value: "test"} + + // Should handle nil request gracefully + if inj.Inject(nil) { + t.Error("nil request should return false") + } +} + +func TestHeaderInjector_NilHeader(t *testing.T) { + inj := &HeaderInjector{Header: "X-Custom", Value: "test"} + + req := &http.Request{ + Header: nil, + } + + if inj.Inject(req) { + t.Error("request with nil Header should return false") + } +} + +func TestBearerInjector_NilRequest(t *testing.T) { + inj := &BearerInjector{Token: "test-token"} + + if inj.Inject(nil) { + t.Error("nil request should return false") + } +} + +func TestBearerInjector_NilHeader(t *testing.T) { + inj := &BearerInjector{Token: "test-token"} + + req := &http.Request{ + Header: nil, + } + + if inj.Inject(req) { + t.Error("request with nil Header should return false") + } +} + +func TestAPIKeyInjector_NilRequest(t *testing.T) { + inj := &APIKeyInjector{HeaderName: "x-api-key", Key: "test-key"} + + if inj.Inject(nil) { + t.Error("nil request should return false") + } +} + +func TestAPIKeyInjector_NilHeader(t *testing.T) { + inj := &APIKeyInjector{HeaderName: "x-api-key", Key: "test-key"} + + req := &http.Request{ + Header: nil, + } + + if inj.Inject(req) { + t.Error("request with nil Header should return false") + } +} diff --git a/internal/credentials/store.go b/internal/credentials/store.go index 3cf5ba3..f97e442 100644 --- a/internal/credentials/store.go +++ b/internal/credentials/store.go @@ -55,6 +55,10 @@ func (s *Store) InjectCredentials(req *http.Request) (bool, bool) { s.mu.RLock() defer s.mu.RUnlock() + if req == nil || req.URL == nil { + return false, false + } + host := req.URL.Host if idx := strings.LastIndex(host, ":"); idx != -1 { host = host[:idx] diff --git a/internal/credentials/store_test.go b/internal/credentials/store_test.go index cf249dc..f724eb4 100644 --- a/internal/credentials/store_test.go +++ b/internal/credentials/store_test.go @@ -169,3 +169,37 @@ func TestStore_InjectCredentials_InjectorFails(t *testing.T) { t.Errorf("Authorization should be empty, got %q", got) } } + +// TestStore_InjectCredentials_NilRequest tests defensive nil checks +func TestStore_InjectCredentials_NilRequest(t *testing.T) { + store := NewStore() + store.AddRoute(Route{ + ExactDomain: "example.com", + Injector: &BearerInjector{Token: "test-token"}, + }) + + // Should handle nil request gracefully + matched, injected := store.InjectCredentials(nil) + if matched || injected { + t.Error("nil request should return (false, false)") + } +} + +func TestStore_InjectCredentials_NilURL(t *testing.T) { + store := NewStore() + store.AddRoute(Route{ + ExactDomain: "example.com", + Injector: &BearerInjector{Token: "test-token"}, + }) + + // Request with nil URL + req := &http.Request{ + URL: nil, + Header: make(http.Header), + } + + matched, injected := store.InjectCredentials(req) + if matched || injected { + t.Error("request with nil URL should return (false, false)") + } +} diff --git a/internal/credentials/token_vending.go b/internal/credentials/token_vending.go index a8107cf..77181a0 100644 --- a/internal/credentials/token_vending.go +++ b/internal/credentials/token_vending.go @@ -8,6 +8,18 @@ import ( "net/http" ) +// errorResponse creates an HTTP error response with plain text content. +func errorResponse(statusCode int, message string) *http.Response { + return &http.Response{ + StatusCode: statusCode, + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{"Content-Type": {"text/plain"}}, + Body: io.NopCloser(bytes.NewReader([]byte(message))), + ContentLength: int64(len(message)), + } +} + // TokenVendor intercepts OAuth2 token exchange requests from the agent's // Google Auth library and returns dummy tokens. // @@ -36,6 +48,10 @@ type tokenResponse struct { // IsTokenExchange returns true if the request is an OAuth2 token exchange // to Google's token endpoint. func IsTokenExchange(req *http.Request) bool { + if req == nil || req.URL == nil { + return false + } + host := req.URL.Host if host == "" { host = req.Host @@ -51,6 +67,11 @@ func IsTokenExchange(req *http.Request) bool { // a dummy access token. The real token injection happens later via the // GCloudInjector when the agent makes API calls to *.googleapis.com. func (tv *TokenVendor) HandleTokenExchange(req *http.Request) *http.Response { + if req == nil || req.URL == nil { + log.Printf("DEFENSIVE_CHECK: HandleTokenExchange called with nil request or URL") + return errorResponse(http.StatusBadRequest, "Malformed token exchange request") + } + resp := &tokenResponse{ AccessToken: "paude-proxy-managed", ExpiresIn: 3600, @@ -60,7 +81,7 @@ func (tv *TokenVendor) HandleTokenExchange(req *http.Request) *http.Response { body, err := json.Marshal(resp) if err != nil { log.Printf("ERROR token vendor: marshal response: %v", err) - return nil + return errorResponse(http.StatusInternalServerError, "Internal token vendor error") } log.Printf("TOKEN_VEND host=%s path=%s (returned dummy token, real injection at request time)", req.URL.Host, req.URL.Path) diff --git a/internal/credentials/token_vending_test.go b/internal/credentials/token_vending_test.go new file mode 100644 index 0000000..9fd8ae9 --- /dev/null +++ b/internal/credentials/token_vending_test.go @@ -0,0 +1,103 @@ +package credentials + +import ( + "io" + "net/http" + "net/url" + "testing" +) + +func TestIsTokenExchange_NilRequest(t *testing.T) { + // Should handle nil request gracefully + if IsTokenExchange(nil) { + t.Error("nil request should return false") + } +} + +func TestIsTokenExchange_NilURL(t *testing.T) { + req := &http.Request{ + Method: http.MethodPost, + URL: nil, + } + + if IsTokenExchange(req) { + t.Error("request with nil URL should return false") + } +} + +func TestIsTokenExchange_ValidRequest(t *testing.T) { + req := &http.Request{ + Method: http.MethodPost, + URL: &url.URL{ + Host: "oauth2.googleapis.com", + Path: "/token", + }, + } + + if !IsTokenExchange(req) { + t.Error("valid token exchange request should return true") + } +} + +func TestHandleTokenExchange_NilRequest(t *testing.T) { + tv := NewTokenVendor() + + resp := tv.HandleTokenExchange(nil) + if resp == nil { + t.Fatal("should not return nil response") + } + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusBadRequest) + } +} + +func TestHandleTokenExchange_NilURL(t *testing.T) { + tv := NewTokenVendor() + + req := &http.Request{ + URL: nil, + } + + resp := tv.HandleTokenExchange(req) + if resp == nil { + t.Fatal("should not return nil response") + } + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusBadRequest) + } +} + +func TestHandleTokenExchange_ValidRequest(t *testing.T) { + tv := NewTokenVendor() + + req := &http.Request{ + URL: &url.URL{ + Host: "oauth2.googleapis.com", + Path: "/token", + }, + } + + resp := tv.HandleTokenExchange(req) + if resp == nil { + t.Fatal("should not return nil response") + } + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + + // Verify response body contains dummy token + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + defer resp.Body.Close() + + bodyStr := string(body) + if bodyStr == "" { + t.Error("response body should not be empty") + } + // Should contain the dummy token + if len(bodyStr) < 10 { + t.Errorf("response body suspiciously short: %q", bodyStr) + } +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index edfa477..e9f02cc 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -310,6 +310,14 @@ type Config struct { // New creates a configured goproxy server. func New(cfg Config) *http.Server { + // DEFENSIVE: Validate required configuration + if cfg.CA == nil { + log.Fatal("FATAL: proxy.New called with nil CA - this is a programming error") + } + if cfg.DomainFilter == nil { + log.Fatal("FATAL: proxy.New called with nil DomainFilter - this is a programming error") + } + proxy := goproxy.NewProxyHttpServer() // Override goproxy's default transport which uses InsecureSkipVerify: true. // We MUST verify upstream server TLS certificates to prevent credential theft via MITM. @@ -334,6 +342,11 @@ func New(cfg Config) *http.Server { // Handle CONNECT requests: client filter, port filtering, domain filtering, MITM proxy.OnRequest().HandleConnectFunc( func(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) { + if ctx.Req == nil { + log.Printf("DEFENSIVE_CHECK: CONNECT handler received nil request for host=%s", host) + return rejectConnect, host + } + // Source IP filtering if cfg.ClientFilter != nil { srcIP := parseClientIP(ctx) @@ -378,6 +391,19 @@ func New(cfg Config) *http.Server { // - Suppress proxy identity headers proxy.OnRequest().DoFunc( func(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) { + if req == nil { + log.Printf("DEFENSIVE_CHECK: DoFunc received nil request from client=%s", clientIP(ctx)) + return nil, nil + } + if req.URL == nil { + log.Printf("DEFENSIVE_CHECK: DoFunc received request with nil URL from client=%s", clientIP(ctx)) + return req, goproxy.NewResponse(req, + goproxy.ContentTypeText, + http.StatusBadRequest, + "Malformed request", + ) + } + // Source IP filtering for plain HTTP proxy requests. // MITM'd HTTPS requests already passed filtering in HandleConnectFunc. if cfg.ClientFilter != nil && req.URL.Scheme == "http" {