diff --git a/chaperone.go b/chaperone.go index fd16741..357ed44 100644 --- a/chaperone.go +++ b/chaperone.go @@ -144,6 +144,22 @@ func configureLogging(rc *runConfig, cfg *config.Config) { "env_var", "CHAPERONE_OBSERVABILITY_ENABLE_BODY_LOGGING", ) } + + // Notify the operator when target_addr logging is set to a non-default + // mode. "path" is informational; "full" is loud because query strings + // may contain secrets, tokens, or PII. + switch cfg.Observability.LogTargetAddr { + case observability.TargetAddrModeHost: + // default — no warning needed + case observability.TargetAddrModePath: + slog.Info("target_addr logging set to 'path' — request paths will appear in logs (host + path, no query)", + "env_var", "CHAPERONE_OBSERVABILITY_LOG_TARGET_ADDR", + ) + case observability.TargetAddrModeFull: + slog.Warn("target_addr logging set to 'full' — full target URLs including query parameters will appear in logs; query strings may contain secrets, tokens, or PII. Use only when explicitly required for audit/debugging", + "env_var", "CHAPERONE_OBSERVABILITY_LOG_TARGET_ADDR", + ) + } } // startProxy wires up the admin and proxy servers, starts them, and blocks @@ -222,10 +238,11 @@ func newProxyServer(plugin sdk.Plugin, rc *runConfig, cfg *config.Config, tracin KeyFile: cfg.Server.TLS.KeyFile, CAFile: cfg.Server.TLS.CAFile, }, - ReadTimeout: *cfg.Upstream.Timeouts.Read, - WriteTimeout: *cfg.Upstream.Timeouts.Write, - IdleTimeout: *cfg.Upstream.Timeouts.Idle, - TracingEnabled: tracingEnabled, + ReadTimeout: *cfg.Upstream.Timeouts.Read, + WriteTimeout: *cfg.Upstream.Timeouts.Write, + IdleTimeout: *cfg.Upstream.Timeouts.Idle, + TracingEnabled: tracingEnabled, + LogTargetAddrMode: cfg.Observability.LogTargetAddr, }) if err != nil { return nil, fmt.Errorf("creating proxy server: %w", err) diff --git a/chaperone_logging_test.go b/chaperone_logging_test.go new file mode 100644 index 0000000..cba3076 --- /dev/null +++ b/chaperone_logging_test.go @@ -0,0 +1,75 @@ +// Copyright 
2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package chaperone + +import ( + "bytes" + "strings" + "testing" + + "github.com/cloudblue/chaperone/internal/config" + "github.com/cloudblue/chaperone/internal/observability" +) + +// TestConfigureLogging_LogTargetAddrWarnings verifies that configureLogging +// emits an INFO when log_target_addr is set to "path" and a loud WARN when +// set to "full". These notify the operator that path/query data may now +// appear in logs (per LITE-34062). +func TestConfigureLogging_LogTargetAddrWarnings(t *testing.T) { + tests := []struct { + name string + mode string + wantLevel string + wantSubstring string + shouldNotMatch string + }{ + { + name: "host: no warning", + mode: "host", + wantLevel: "", + wantSubstring: "", + shouldNotMatch: "target_addr logging", + }, + { + name: "path: informational INFO", + mode: "path", + wantLevel: `"level":"INFO"`, + wantSubstring: "target_addr logging set to 'path'", + }, + { + name: "full: loud WARN", + mode: "full", + wantLevel: `"level":"WARN"`, + wantSubstring: "target_addr logging set to 'full'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + rc := &runConfig{logOutput: &buf} + cfg := &config.Config{} + cfg.Observability.LogLevel = "info" + cfg.Observability.LogTargetAddr = observability.TargetAddrMode(tt.mode) + + configureLogging(rc, cfg) + + out := buf.String() + if tt.shouldNotMatch != "" && strings.Contains(out, tt.shouldNotMatch) { + t.Errorf("host mode must not emit a target_addr warning, got: %s", out) + } + if tt.wantSubstring != "" { + if !strings.Contains(out, tt.wantSubstring) { + t.Errorf("expected log to contain %q, got: %s", tt.wantSubstring, out) + } + if !strings.Contains(out, tt.wantLevel) { + t.Errorf("expected level %q, got: %s", tt.wantLevel, out) + } + if !strings.Contains(out, `"env_var":"CHAPERONE_OBSERVABILITY_LOG_TARGET_ADDR"`) { + t.Errorf("expected env_var attribute in log, got: %s", out) + } 
+ } + }) + } +} diff --git a/configs/config.example.yaml b/configs/config.example.yaml index e5f1530..d402db5 100644 --- a/configs/config.example.yaml +++ b/configs/config.example.yaml @@ -146,6 +146,31 @@ observability: # Env: CHAPERONE_OBSERVABILITY_ENABLE_TRACING enable_tracing: false + # Controls how much of the upstream target URL is reported in the + # `target_addr` log field. The same value is applied uniformly to every + # log line that references the target (request completed, upstream + # response, allow-list events, plugin errors, etc.). + # + # Valid values: + # "host" (default) - Authority only (host[:port]). No scheme, no path, + # no query. Safest. Example: "api.vendor.com:8443". + # "path" - scheme://host[:port]/path. Path appears, query + # stripped. Example: "https://api.vendor.com/v1/users". + # Useful for "what endpoint was called" auditing. + # "full" - Full URL including query. Userinfo (user:pass@) is + # always stripped, in every mode. + # Example: "https://api.vendor.com/v1/users?key=val". + # + # SECURITY: "path" and "full" can leak sensitive information. Path + # segments may contain PII (e.g., emails or IDs); query strings are a + # common location for tokens and API keys. The proxy emits a startup + # WARN for "full" and a startup INFO for "path". Use "host" unless an audit/debugging + # requirement justifies the extra detail. + # + # Default: "host" + # Env: CHAPERONE_OBSERVABILITY_LOG_TARGET_ADDR + log_target_addr: "host" + # Additional headers to redact from logs and strip from responses. 
# These are MERGED with the built-in defaults which are always included: # Authorization, Proxy-Authorization, Cookie, Set-Cookie, X-API-Key, X-Auth-Token diff --git a/docs/reference/configuration.md b/docs/reference/configuration.md index f4526e5..caff8b8 100644 --- a/docs/reference/configuration.md +++ b/docs/reference/configuration.md @@ -125,6 +125,7 @@ observability: log_level: "info" enable_profiling: false enable_tracing: false + log_target_addr: "host" sensitive_headers: - "X-Custom-Secret" - "X-Vendor-Token" @@ -136,6 +137,7 @@ observability: | `enable_profiling` | `CHAPERONE_OBSERVABILITY_ENABLE_PROFILING` | bool | `false` | Enable `/debug/pprof` endpoints on the admin port | | `enable_tracing` | `CHAPERONE_OBSERVABILITY_ENABLE_TRACING` | bool | `false` | Enable OpenTelemetry distributed tracing (see [Tracing](#tracing)) | | — | `CHAPERONE_OBSERVABILITY_ENABLE_BODY_LOGGING` | bool | `false` | Log request/response bodies at debug level. **Env-var only** — cannot be set in the YAML file (security safeguard). A startup warning is emitted when enabled. | +| `log_target_addr` | `CHAPERONE_OBSERVABILITY_LOG_TARGET_ADDR` | string | `host` | Detail level of the upstream target in the `target_addr` log field: `host` / `path` / `full`. See [Target Address Logging](#target-address-logging). | | `sensitive_headers` | — | []string | See below | Additional headers to redact (merged with defaults) | #### Sensitive Headers @@ -161,6 +163,32 @@ observability: - "X-Vendor-Token" ``` +#### Target Address Logging + +The `log_target_addr` setting controls how much of the upstream target URL +appears in the `target_addr` log field. The same value is applied uniformly +to every log line that references the target — `request completed`, +`upstream response`, allow-list events, plugin errors, and DEBUG breadcrumbs. + +| Mode | Output example | Use it when | +|------|---------------|-------------| +| `host` (default) | `api.vendor.com:8443` | Always safe. 
Authority only — no scheme, no path, no query. | +| `path` | `https://api.vendor.com:8443/v1/users` | You need to know which endpoint was called. Path appears; query is stripped. | +| `full` | `https://api.vendor.com:8443/v1/users?key=val` | You need full request audit/debugging. Includes the query string. | + +**Userinfo (`user:pass@host`) is always stripped, in every mode.** + +> **Security:** `path` and `full` may expose sensitive information. +> Path segments often carry IDs or PII (e.g., `/users/alice@example.com`); +> query strings are a common location for tokens and API keys. +> Chaperone emits a startup `WARN` when `full` is selected, and an +> informational `INFO` when `path` is selected. + +For full per-request audit trails without leaking sensitive data into +plaintext logs, prefer enabling [OpenTelemetry tracing](#tracing) — spans +include the full target URL but are exported through your observability +pipeline rather than the local log stream. + ### Tracing OpenTelemetry distributed tracing is controlled by two independent mechanisms: diff --git a/internal/config/config.go b/internal/config/config.go index 11e32d8..dfc90d4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,7 +3,11 @@ package config -import "time" +import ( + "time" + + "github.com/cloudblue/chaperone/internal/observability" +) // Config is the root configuration structure for Chaperone. type Config struct { @@ -99,6 +103,15 @@ type ObservabilityConfig struct { // environment variable, not via config file. A startup warning is emitted when enabled. // Per Design Spec Section 5.3 (Body Safety). EnableBodyLogging bool `yaml:"-"` + // LogTargetAddr controls how much of the upstream target URL is reported + // in the `target_addr` log field. Valid values: + // - "host" (default): authority only (host[:port]). Safe; no path/query. + // - "path": scheme://host[:port]/path. Path appears, query stripped. + // - "full": full URL including query. 
Userinfo always stripped. + // A startup notice is emitted for "path" (INFO) and "full" (WARN) — they may + // expose path segments or query parameters that contain PII or secrets. + // Default: "host". + LogTargetAddr observability.TargetAddrMode `yaml:"log_target_addr"` // SensitiveHeaders is the list of additional headers to redact from logs // and strip from responses. These are merged with the built-in defaults // (Authorization, Proxy-Authorization, Cookie, Set-Cookie, X-API-Key, diff --git a/internal/config/config_test.go b/internal/config/config_test.go index f1b547f..1dcbcd8 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/cloudblue/chaperone/internal/observability" "github.com/cloudblue/chaperone/internal/router" ) @@ -944,6 +945,111 @@ observability: } } +// TestLoad_LogTargetAddr_DefaultsToHost verifies the secure default for the +// new log_target_addr field. An unset value must yield "host" — neither +// "path" nor "full" should ever be the default. +func TestLoad_LogTargetAddr_DefaultsToHost(t *testing.T) { + yamlContent := ` +server: + tls: + enabled: false +upstream: + allow_list: + api.example.com: + - "/**" +` + configPath := writeTestConfig(t, yamlContent) + + cfg, err := Load(configPath) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Observability.LogTargetAddr != DefaultLogTargetAddr { + t.Errorf("LogTargetAddr = %q, want %q (secure default)", + cfg.Observability.LogTargetAddr, DefaultLogTargetAddr) + } +} + +// TestLoad_LogTargetAddr_YAMLAcceptsAllValidValues verifies that all three +// valid modes can be set from the YAML config. 
+func TestLoad_LogTargetAddr_YAMLAcceptsAllValidValues(t *testing.T) { + for _, mode := range observability.ValidTargetAddrModes { + t.Run(mode, func(t *testing.T) { + yamlContent := ` +server: + tls: + enabled: false +upstream: + allow_list: + api.example.com: + - "/**" +observability: + log_target_addr: "` + mode + `" +` + configPath := writeTestConfig(t, yamlContent) + cfg, err := Load(configPath) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Observability.LogTargetAddr != observability.TargetAddrMode(mode) { + t.Errorf("LogTargetAddr = %q, want %q", cfg.Observability.LogTargetAddr, mode) + } + }) + } +} + +// TestLoad_LogTargetAddr_EnvOverridesYAML verifies the env var takes +// precedence over the YAML value. +func TestLoad_LogTargetAddr_EnvOverridesYAML(t *testing.T) { + yamlContent := ` +server: + tls: + enabled: false +upstream: + allow_list: + api.example.com: + - "/**" +observability: + log_target_addr: "host" +` + configPath := writeTestConfig(t, yamlContent) + t.Setenv("CHAPERONE_OBSERVABILITY_LOG_TARGET_ADDR", "full") + + cfg, err := Load(configPath) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Observability.LogTargetAddr != "full" { + t.Errorf("LogTargetAddr = %q, want %q (env should override YAML)", + cfg.Observability.LogTargetAddr, "full") + } +} + +// TestLoad_LogTargetAddr_RejectsInvalidValue verifies that an unknown mode +// fails validation rather than silently falling back. 
+func TestLoad_LogTargetAddr_RejectsInvalidValue(t *testing.T) { + yamlContent := ` +server: + tls: + enabled: false +upstream: + allow_list: + api.example.com: + - "/**" +observability: + log_target_addr: "verbose" +` + configPath := writeTestConfig(t, yamlContent) + + _, err := Load(configPath) + if err == nil { + t.Fatal("expected validation error for invalid log_target_addr value, got nil") + } + if !errors.Is(err, ErrInvalidLogTargetAddr) { + t.Errorf("expected ErrInvalidLogTargetAddr, got: %v", err) + } +} + // TestLoad_EnableBodyLogging_DefaultFalse verifies the secure default. func TestLoad_EnableBodyLogging_DefaultFalse(t *testing.T) { yamlContent := ` diff --git a/internal/config/defaults.go b/internal/config/defaults.go index bba3864..d037284 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -69,6 +69,10 @@ const ( DefaultEnableProfiling = false // DefaultEnableTracing is the secure default for tracing (disabled). DefaultEnableTracing = false + // DefaultLogTargetAddr is the default mode for the target_addr log field + // (host-only). Safe default — no path or query is exposed in logs unless + // the operator explicitly opts in to "path" or "full". + DefaultLogTargetAddr = "host" ) // defaultSensitiveHeaders returns the list of headers that MUST be redacted diff --git a/internal/config/loader.go b/internal/config/loader.go index 11807c6..84c7b55 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -11,6 +11,8 @@ import ( "time" "gopkg.in/yaml.v3" + + "github.com/cloudblue/chaperone/internal/observability" ) // EnvPrefix is the prefix for environment variable overrides. 
@@ -178,6 +180,9 @@ func applyObservabilityDefaults(cfg *ObservabilityConfig) { if cfg.LogLevel == "" { cfg.LogLevel = DefaultLogLevel } + if cfg.LogTargetAddr == "" { + cfg.LogTargetAddr = DefaultLogTargetAddr + } // EnableProfiling defaults to false (secure default), which is Go zero value // Security: Always merge user-provided sensitive headers with mandatory @@ -328,6 +333,9 @@ func applyObservabilityEnvOverrides(cfg *Config) error { } cfg.Observability.EnableBodyLogging = b } + if v := getEnv("OBSERVABILITY_LOG_TARGET_ADDR"); v != "" { + cfg.Observability.LogTargetAddr = observability.TargetAddrMode(v) + } return nil } diff --git a/internal/config/validate.go b/internal/config/validate.go index 09556a8..c1d6104 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -11,6 +11,7 @@ import ( "slices" "strconv" + "github.com/cloudblue/chaperone/internal/observability" "github.com/cloudblue/chaperone/internal/router" ) @@ -22,6 +23,8 @@ var ( ErrEmptyAllowList = errors.New("upstream.allow_list must not be empty") // ErrInvalidLogLevel is returned when log_level is not a valid value. ErrInvalidLogLevel = errors.New("observability.log_level must be one of: debug, info, warn, error") + // ErrInvalidLogTargetAddr is returned when log_target_addr is not a valid value. + ErrInvalidLogTargetAddr = errors.New("observability.log_target_addr must be one of: host, path, full") // ErrInvalidServerAddr is returned when server.addr is invalid. ErrInvalidServerAddr = errors.New("server.addr is invalid") // ErrInvalidAdminAddr is returned when server.admin_addr is invalid. @@ -228,10 +231,21 @@ func validateTimeouts(cfg *TimeoutConfig) error { // validateObservabilityConfig validates the observability configuration section. 
func validateObservabilityConfig(cfg *ObservabilityConfig) error { + var errs []error + if cfg.LogLevel != "" && !slices.Contains(ValidLogLevels, cfg.LogLevel) { - return fmt.Errorf("%w: got %q", ErrInvalidLogLevel, cfg.LogLevel) + errs = append(errs, fmt.Errorf("%w: got %q", ErrInvalidLogLevel, cfg.LogLevel)) + } + + if cfg.LogTargetAddr != "" { + if _, err := observability.ParseTargetAddrMode(string(cfg.LogTargetAddr)); err != nil { + errs = append(errs, fmt.Errorf("%w: got %q", ErrInvalidLogTargetAddr, cfg.LogTargetAddr)) + } } + if len(errs) > 0 { + return errors.Join(errs...) + } return nil } diff --git a/internal/observability/request_logger.go b/internal/observability/request_logger.go index 21d87a5..5dac22a 100644 --- a/internal/observability/request_logger.go +++ b/internal/observability/request_logger.go @@ -6,7 +6,6 @@ package observability import ( "log/slog" "net/http" - "net/url" "strings" "time" ) @@ -86,6 +85,8 @@ func (rc *ResponseCapturer) Status() int { // - logger: structured logger for output // - headerPrefix: the context header prefix (e.g., "X-Connect"). The middleware // constructs header names internally using the stable suffix constants. +// - addrMode: how to format the upstream target in the target_addr field +// (see TargetAddrMode). Empty defaults to host-only. // - next: the downstream handler // // Fields emitted: @@ -97,7 +98,8 @@ func (rc *ResponseCapturer) Status() int { // - vendor_id: Vendor ID from the -Vendor-ID request header // - marketplace_id: Marketplace ID from the -Marketplace-ID request header // - product_id: Product ID from the -Product-ID request header -// - target_host: Host extracted from the -Target-URL request header +// - target_addr: Upstream target formatted per addrMode (host / path / full). +// Sourced from the -Target-URL request header. 
// - client_ip: Client IP from proxy headers (X-Forwarded-For > X-Real-IP); // empty when no proxy headers are present (use remote_addr instead) // - remote_addr: Raw TCP peer address (always r.RemoteAddr, useful for @@ -107,7 +109,7 @@ func (rc *ResponseCapturer) Status() int { // already in the request context when this handler receives the request. // Uses defer to ensure logging occurs even if downstream handlers panic // (when used with panic recovery middleware). -func RequestLoggerMiddleware(logger *slog.Logger, headerPrefix string, next http.Handler) http.Handler { +func RequestLoggerMiddleware(logger *slog.Logger, headerPrefix string, addrMode TargetAddrMode, next http.Handler) http.Handler { vendorHdr := headerPrefix + "-Vendor-ID" marketplaceHdr := headerPrefix + "-Marketplace-ID" productHdr := headerPrefix + "-Product-ID" @@ -131,7 +133,7 @@ func RequestLoggerMiddleware(logger *slog.Logger, headerPrefix string, next http "vendor_id", r.Header.Get(vendorHdr), "marketplace_id", r.Header.Get(marketplaceHdr), "product_id", r.Header.Get(productHdr), - "target_host", extractHost(r.Header.Get(targetURLHdr)), + "target_addr", FormatTargetAddr(r.Header.Get(targetURLHdr), addrMode), "client_ip", ClientIP(r), "remote_addr", r.RemoteAddr, ) @@ -141,16 +143,6 @@ func RequestLoggerMiddleware(logger *slog.Logger, headerPrefix string, next http }) } -// extractHost parses rawURL and returns only the host (with port if present). -// Returns an empty string if the URL is invalid or has no host. -func extractHost(rawURL string) string { - u, err := url.Parse(rawURL) - if err != nil || u.Host == "" { - return "" - } - return u.Host -} - // ClientIP extracts the client IP from proxy headers only. // Returns the first IP from X-Forwarded-For, or X-Real-IP as fallback. 
// Returns "" when no proxy headers are present — in that case the diff --git a/internal/observability/request_logger_test.go b/internal/observability/request_logger_test.go index 9d99e6a..b643240 100644 --- a/internal/observability/request_logger_test.go +++ b/internal/observability/request_logger_test.go @@ -26,7 +26,7 @@ type logEntry struct { VendorID string `json:"vendor_id"` MarketplaceID string `json:"marketplace_id"` ProductID string `json:"product_id"` - TargetHost string `json:"target_host"` + TargetAddr string `json:"target_addr"` ClientIP string `json:"client_ip"` RemoteAddr string `json:"remote_addr"` } @@ -134,7 +134,7 @@ func TestRequestLoggerMiddleware_LogsRequestFields(t *testing.T) { w.WriteHeader(http.StatusOK) }) - handler := RequestLoggerMiddleware(logger, "X-Connect", inner) + handler := RequestLoggerMiddleware(logger, "X-Connect", TargetAddrModeHost, inner) r := httptest.NewRequest(http.MethodPost, "/proxy", nil) r = r.WithContext(WithTraceID(r.Context(), "test-trace-123")) r.Header.Set("X-Connect-Vendor-ID", "microsoft") @@ -174,8 +174,8 @@ func TestRequestLoggerMiddleware_LogsRequestFields(t *testing.T) { if entry.ProductID != "PRD-001" { t.Errorf("product_id = %q, want %q", entry.ProductID, "PRD-001") } - if entry.TargetHost != "graph.microsoft.com" { - t.Errorf("target_host = %q, want %q", entry.TargetHost, "graph.microsoft.com") + if entry.TargetAddr != "graph.microsoft.com" { + t.Errorf("target_addr = %q, want %q", entry.TargetAddr, "graph.microsoft.com") } if entry.ClientIP != "" { t.Errorf("client_ip = %q, want empty (no proxy headers)", entry.ClientIP) @@ -193,7 +193,7 @@ func TestRequestLoggerMiddleware_CapturesErrorStatus(t *testing.T) { w.WriteHeader(http.StatusBadGateway) }) - handler := RequestLoggerMiddleware(logger, "X-Connect", inner) + handler := RequestLoggerMiddleware(logger, "X-Connect", TargetAddrModeHost, inner) r := httptest.NewRequest(http.MethodGet, "/proxy", nil) r = r.WithContext(WithTraceID(r.Context(), 
"err-trace")) w := httptest.NewRecorder() @@ -214,7 +214,7 @@ func TestRequestLoggerMiddleware_NoTraceID_LogsEmpty(t *testing.T) { w.WriteHeader(http.StatusOK) }) - handler := RequestLoggerMiddleware(logger, "X-Connect", inner) + handler := RequestLoggerMiddleware(logger, "X-Connect", TargetAddrModeHost, inner) r := httptest.NewRequest(http.MethodGet, "/proxy", nil) // Deliberately no trace ID in context w := httptest.NewRecorder() @@ -236,7 +236,7 @@ func TestRequestLoggerMiddleware_LogsOnPanic(t *testing.T) { }) // Wrap with panic recovery INSIDE request logger, so logger still fires - handler := RequestLoggerMiddleware(logger, "X-Connect", panicRecoveryForTest(panicky)) + handler := RequestLoggerMiddleware(logger, "X-Connect", TargetAddrModeHost, panicRecoveryForTest(panicky)) r := httptest.NewRequest(http.MethodGet, "/proxy", nil) r = r.WithContext(WithTraceID(r.Context(), "panic-trace")) w := httptest.NewRecorder() @@ -260,7 +260,7 @@ func TestRequestLoggerMiddleware_ClientIPFromXForwardedFor(t *testing.T) { w.WriteHeader(http.StatusOK) }) - handler := RequestLoggerMiddleware(logger, "X-Connect", inner) + handler := RequestLoggerMiddleware(logger, "X-Connect", TargetAddrModeHost, inner) r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Set("X-Forwarded-For", "203.0.113.50") w := httptest.NewRecorder() @@ -285,7 +285,7 @@ func TestRequestLoggerMiddleware_TraceIDFromOuterMiddleware(t *testing.T) { }) // Production order: TraceID (outermost) → Logger → handler - handler := RequestLoggerMiddleware(logger, "X-Connect", inner) + handler := RequestLoggerMiddleware(logger, "X-Connect", TargetAddrModeHost, inner) handler = TraceIDMiddleware("X-Trace-ID", handler) r := httptest.NewRequest(http.MethodGet, "/test", nil) @@ -311,7 +311,7 @@ func TestRequestLoggerMiddleware_TraceIDGenerated(t *testing.T) { }) // Production order: TraceID (outermost) → Logger → handler - handler := RequestLoggerMiddleware(logger, "X-Connect", inner) + handler := 
RequestLoggerMiddleware(logger, "X-Connect", TargetAddrModeHost, inner) handler = TraceIDMiddleware("X-Trace-ID", handler) r := httptest.NewRequest(http.MethodGet, "/test", nil) @@ -326,7 +326,10 @@ func TestRequestLoggerMiddleware_TraceIDGenerated(t *testing.T) { } } -func TestRequestLoggerMiddleware_LogsTargetHost_ExtractsHostOnly(t *testing.T) { +// TestRequestLoggerMiddleware_HostMode_StripsPathAndQuery verifies the +// default mode emits only the authority and never lets path or query +// components leak into the log output. +func TestRequestLoggerMiddleware_HostMode_StripsPathAndQuery(t *testing.T) { var buf bytes.Buffer logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) @@ -334,7 +337,7 @@ func TestRequestLoggerMiddleware_LogsTargetHost_ExtractsHostOnly(t *testing.T) { w.WriteHeader(http.StatusOK) }) - handler := RequestLoggerMiddleware(logger, "X-Connect", inner) + handler := RequestLoggerMiddleware(logger, "X-Connect", TargetAddrModeHost, inner) r := httptest.NewRequest(http.MethodGet, "/proxy", nil) // URL with query string and path — only the host should appear in the log r.Header.Set("X-Connect-Target-URL", "https://api.vendor.com/v1/users?api_key=secret&token=abc") @@ -343,8 +346,8 @@ func TestRequestLoggerMiddleware_LogsTargetHost_ExtractsHostOnly(t *testing.T) { handler.ServeHTTP(w, r) entry := parseLogEntry(t, buf.Bytes()) - if entry.TargetHost != "api.vendor.com" { - t.Errorf("target_host = %q, want %q (only host, no path/query)", entry.TargetHost, "api.vendor.com") + if entry.TargetAddr != "api.vendor.com" { + t.Errorf("target_addr = %q, want %q (only host, no path/query)", entry.TargetAddr, "api.vendor.com") } // Verify query params did not leak into any logged field logOutput := buf.String() @@ -353,56 +356,57 @@ func TestRequestLoggerMiddleware_LogsTargetHost_ExtractsHostOnly(t *testing.T) { } } -func TestExtractHost(t *testing.T) { - tests := []struct { - name string - input string - want string - }{ - 
{ - name: "full URL with path", - input: "https://api.vendor.com/v1/users", - want: "api.vendor.com", - }, - { - name: "URL with port", - input: "https://api.vendor.com:8443/v1", - want: "api.vendor.com:8443", - }, - { - name: "URL with query string", - input: "https://api.vendor.com/v1?key=secret", - want: "api.vendor.com", - }, - { - name: "URL without path", - input: "https://api.vendor.com", - want: "api.vendor.com", - }, - { - name: "empty URL returns empty", - input: "", - want: "", - }, - { - name: "invalid URL returns empty", - input: "://invalid", - want: "", - }, - { - name: "path-only URL returns empty", - input: "/just/a/path", - want: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := extractHost(tt.input) - if got != tt.want { - t.Errorf("extractHost(%q) = %q, want %q", tt.input, got, tt.want) - } - }) +// TestRequestLoggerMiddleware_PathMode_KeepsPathStripsQuery verifies path +// mode emits scheme + host + path but never includes the query string. 
+func TestRequestLoggerMiddleware_PathMode_KeepsPathStripsQuery(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) + + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + handler := RequestLoggerMiddleware(logger, "X-Connect", TargetAddrModePath, inner) + r := httptest.NewRequest(http.MethodGet, "/proxy", nil) + r.Header.Set("X-Connect-Target-URL", "https://api.vendor.com/v1/users?api_key=secret") + w := httptest.NewRecorder() + + handler.ServeHTTP(w, r) + + entry := parseLogEntry(t, buf.Bytes()) + want := "https://api.vendor.com/v1/users" + if entry.TargetAddr != want { + t.Errorf("target_addr = %q, want %q (scheme+host+path, no query)", entry.TargetAddr, want) + } + if strings.Contains(buf.String(), "api_key") || strings.Contains(buf.String(), "secret") { + t.Errorf("path mode must not leak query params, got: %s", buf.String()) + } +} + +// TestRequestLoggerMiddleware_FullMode_KeepsQueryStripsUserinfo verifies +// full mode emits the query string but always strips userinfo. 
+func TestRequestLoggerMiddleware_FullMode_KeepsQueryStripsUserinfo(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) + + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + handler := RequestLoggerMiddleware(logger, "X-Connect", TargetAddrModeFull, inner) + r := httptest.NewRequest(http.MethodGet, "/proxy", nil) + r.Header.Set("X-Connect-Target-URL", "https://user:pass@api.vendor.com/v1?key=val") + w := httptest.NewRecorder() + + handler.ServeHTTP(w, r) + + entry := parseLogEntry(t, buf.Bytes()) + want := "https://api.vendor.com/v1?key=val" + if entry.TargetAddr != want { + t.Errorf("target_addr = %q, want %q (full URL, userinfo stripped)", entry.TargetAddr, want) + } + if strings.Contains(buf.String(), "user:pass") { + t.Errorf("full mode must always strip userinfo, got: %s", buf.String()) } } diff --git a/internal/observability/target_addr.go b/internal/observability/target_addr.go new file mode 100644 index 0000000..1b99341 --- /dev/null +++ b/internal/observability/target_addr.go @@ -0,0 +1,101 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package observability + +import ( + "fmt" + "net/url" +) + +// TargetAddrMode controls how much detail of the upstream target URL is +// emitted in the `target_addr` log field. +// +// The mode is configured via observability.log_target_addr (or the +// CHAPERONE_OBSERVABILITY_LOG_TARGET_ADDR env var) and applied uniformly +// across every log line that reports the target. +type TargetAddrMode string + +const ( + // TargetAddrModeHost emits only the authority (host[:port]) of the target. + // No scheme, no path, no query. Default mode — minimum information, + // maximum safety. Example output: "api.vendor.com:8443". 
+ TargetAddrModeHost TargetAddrMode = "host" + + // TargetAddrModePath emits scheme://host[:port]/path with the query + // string stripped. Useful for "what endpoint was called" auditing + // without leaking sensitive query parameters. + // Example output: "https://api.vendor.com:8443/v1/users". + TargetAddrModePath TargetAddrMode = "path" + + // TargetAddrModeFull emits the full URL including the query string + // but with userinfo (user:pass@) stripped. Use only when explicitly + // required for audit/debugging — query strings may contain secrets. + // Example output: "https://api.vendor.com:8443/v1/users?key=val". + TargetAddrModeFull TargetAddrMode = "full" +) + +// ValidTargetAddrModes is the list of accepted target_addr mode values. +var ValidTargetAddrModes = []string{ + string(TargetAddrModeHost), + string(TargetAddrModePath), + string(TargetAddrModeFull), +} + +// ParseTargetAddrMode returns the mode for the given string, or an error +// if the value is not one of host/path/full. +func ParseTargetAddrMode(s string) (TargetAddrMode, error) { + switch TargetAddrMode(s) { + case TargetAddrModeHost, TargetAddrModePath, TargetAddrModeFull: + return TargetAddrMode(s), nil + default: + return "", fmt.Errorf("invalid target_addr mode %q (valid: %v)", s, ValidTargetAddrModes) + } +} + +// FormatTargetAddr formats a raw URL string for logging according to mode. +// Returns "" when rawURL is empty, malformed, or has no host. Userinfo +// is always stripped, in every mode. +// +// Unknown modes default to TargetAddrModeHost — the safest fallback. This +// should never happen in practice because the config layer validates the +// mode before reaching the log site. +func FormatTargetAddr(rawURL string, mode TargetAddrMode) string { + u, err := url.Parse(rawURL) + if err != nil || u.Host == "" { + return "" + } + return FormatTargetAddrFromURL(u, mode) +} + +// FormatTargetAddrFromURL is the same as FormatTargetAddr but accepts a +// pre-parsed *url.URL. 
Use this at sites that already parsed the URL +// (e.g. internal/proxy/server.go) to avoid re-parsing on every request. +// +// The input *url.URL is not mutated. Returns "" if u is nil or has no host. +func FormatTargetAddrFromURL(u *url.URL, mode TargetAddrMode) string { + if u == nil || u.Host == "" { + return "" + } + + switch mode { + case TargetAddrModeHost: + // Authority only — no allocation needed. + return u.Host + case TargetAddrModePath, TargetAddrModeFull: + // Clone to avoid mutating the caller's URL, then strip components + // according to mode. Userinfo and fragment are stripped in both modes; + // the query is stripped only in path mode. + cloned := *u + cloned.User = nil + cloned.Fragment = "" + cloned.RawFragment = "" + if mode == TargetAddrModePath { + cloned.RawQuery = "" + } + return cloned.String() + default: + // Unknown mode — fall back to host-only as a safe default. + return u.Host + } +} diff --git a/internal/observability/target_addr_test.go b/internal/observability/target_addr_test.go new file mode 100644 index 0000000..7877ab4 --- /dev/null +++ b/internal/observability/target_addr_test.go @@ -0,0 +1,244 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package observability + +import ( + "net/url" + "testing" +) + +func TestParseTargetAddrMode(t *testing.T) { + tests := []struct { + name string + input string + want TargetAddrMode + wantErr bool + }{ + {"host", "host", TargetAddrModeHost, false}, + {"path", "path", TargetAddrModePath, false}, + {"full", "full", TargetAddrModeFull, false}, + {"empty", "", "", true}, + {"unknown", "verbose", "", true}, + {"case-sensitive (uppercase rejected)", "HOST", "", true}, + {"trailing space rejected", "host ", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseTargetAddrMode(tt.input) + if tt.wantErr { + if err == nil { + t.Errorf("ParseTargetAddrMode(%q) want error, got %q", tt.input, got) + } + return + } + if err 
!= nil { + t.Errorf("ParseTargetAddrMode(%q) unexpected error: %v", tt.input, err) + } + if got != tt.want { + t.Errorf("ParseTargetAddrMode(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestFormatTargetAddr(t *testing.T) { + tests := []struct { + name string + input string + mode TargetAddrMode + want string + }{ + // host mode — only authority, no scheme. + { + name: "host: bare host", + input: "https://api.vendor.com/v1/users", + mode: TargetAddrModeHost, + want: "api.vendor.com", + }, + { + name: "host: preserves port", + input: "https://api.vendor.com:8443/v1", + mode: TargetAddrModeHost, + want: "api.vendor.com:8443", + }, + { + name: "host: strips userinfo", + input: "https://user:pass@api.vendor.com/v1", + mode: TargetAddrModeHost, + want: "api.vendor.com", + }, + { + name: "host: strips query", + input: "https://api.vendor.com/v1?token=abc", + mode: TargetAddrModeHost, + want: "api.vendor.com", + }, + { + name: "host: no scheme in output", + input: "https://api.vendor.com:8443/v1", + mode: TargetAddrModeHost, + want: "api.vendor.com:8443", // distinguishes from "https://..." + }, + + // path mode — scheme + host + path, no query, no userinfo. 
+ { + name: "path: full URL", + input: "https://api.vendor.com/v1/users", + mode: TargetAddrModePath, + want: "https://api.vendor.com/v1/users", + }, + { + name: "path: strips query", + input: "https://api.vendor.com/v1/users?api_key=secret&token=abc", + mode: TargetAddrModePath, + want: "https://api.vendor.com/v1/users", + }, + { + name: "path: strips userinfo", + input: "https://user:pass@api.vendor.com/v1", + mode: TargetAddrModePath, + want: "https://api.vendor.com/v1", + }, + { + name: "path: strips fragment", + input: "https://api.vendor.com/v1#section", + mode: TargetAddrModePath, + want: "https://api.vendor.com/v1", + }, + { + name: "path: no path keeps host root", + input: "https://api.vendor.com", + mode: TargetAddrModePath, + want: "https://api.vendor.com", + }, + { + name: "path: preserves port", + input: "https://api.vendor.com:8443/v1/users", + mode: TargetAddrModePath, + want: "https://api.vendor.com:8443/v1/users", + }, + { + name: "path: keeps sensitive path segments (operator opted in)", + input: "https://api.vendor.com/users/alice@example.com", + mode: TargetAddrModePath, + want: "https://api.vendor.com/users/alice@example.com", + }, + + // full mode — everything except userinfo and fragment. 
+ { + name: "full: preserves query", + input: "https://api.vendor.com/v1?key=val&token=abc", + mode: TargetAddrModeFull, + want: "https://api.vendor.com/v1?key=val&token=abc", + }, + { + name: "full: strips userinfo", + input: "https://user:pass@api.vendor.com/v1?key=val", + mode: TargetAddrModeFull, + want: "https://api.vendor.com/v1?key=val", + }, + { + name: "full: strips fragment", + input: "https://api.vendor.com/v1?key=val#section", + mode: TargetAddrModeFull, + want: "https://api.vendor.com/v1?key=val", + }, + { + name: "full: combined — strips userinfo+fragment, keeps everything else", + input: "https://u:p@api.vendor.com:8443/v1/users/alice@example.com?token=abc#frag", + mode: TargetAddrModeFull, + want: "https://api.vendor.com:8443/v1/users/alice@example.com?token=abc", + }, + + // invalid input — return empty string, never panic. + {"invalid: parse error", "://invalid", TargetAddrModeFull, ""}, + {"invalid: empty", "", TargetAddrModeFull, ""}, + {"invalid: relative URL with no host", "/just/a/path", TargetAddrModeHost, ""}, + {"invalid: path-only no host (path mode)", "/users/123", TargetAddrModePath, ""}, + + // unknown mode falls back to host (safe default). + { + name: "unknown mode: falls back to host-only", + input: "https://api.vendor.com/v1?token=abc", + mode: TargetAddrMode("unknown"), + want: "api.vendor.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := FormatTargetAddr(tt.input, tt.mode) + if got != tt.want { + t.Errorf("FormatTargetAddr(%q, %q) = %q, want %q", tt.input, tt.mode, got, tt.want) + } + }) + } +} + +// TestFormatTargetAddrFromURL_NilSafe verifies the nil-URL guard. 
+func TestFormatTargetAddrFromURL_NilSafe(t *testing.T) { + if got := FormatTargetAddrFromURL(nil, TargetAddrModeHost); got != "" { + t.Errorf("FormatTargetAddrFromURL(nil, host) = %q, want %q", got, "") + } + if got := FormatTargetAddrFromURL(&url.URL{}, TargetAddrModeFull); got != "" { + t.Errorf("FormatTargetAddrFromURL(empty URL, full) = %q, want %q", got, "") + } +} + +// TestFormatTargetAddrFromURL_DoesNotMutate verifies that the caller's +// *url.URL is not modified by the helper (important — server.go reuses the +// parsed URL for forwarding). +func TestFormatTargetAddrFromURL_DoesNotMutate(t *testing.T) { + original := "https://user:pass@api.vendor.com/v1?key=val#frag" + u, err := url.Parse(original) + if err != nil { + t.Fatalf("url.Parse failed: %v", err) + } + + for _, mode := range []TargetAddrMode{TargetAddrModeHost, TargetAddrModePath, TargetAddrModeFull} { + _ = FormatTargetAddrFromURL(u, mode) + } + + if u.String() != original { + t.Errorf("FormatTargetAddrFromURL mutated input URL: got %q, want %q", u.String(), original) + } + if u.User == nil || u.User.String() != "user:pass" { + t.Errorf("FormatTargetAddrFromURL stripped userinfo from caller's URL: %v", u.User) + } + if u.RawQuery != "key=val" { + t.Errorf("FormatTargetAddrFromURL stripped query from caller's URL: %q", u.RawQuery) + } + if u.Fragment != "frag" { + t.Errorf("FormatTargetAddrFromURL stripped fragment from caller's URL: %q", u.Fragment) + } +} + +// TestFormatTargetAddrFromURL_EquivalentToParsing verifies that the URL and +// raw-string variants produce identical output for valid URLs. This is the +// invariant that lets the proxy use either form interchangeably. 
+func TestFormatTargetAddrFromURL_EquivalentToParsing(t *testing.T) { + rawURLs := []string{ + "https://api.vendor.com/v1/users", + "https://api.vendor.com:8443/v1/users?key=val", + "https://user:pass@api.vendor.com/v1?token=abc#frag", + "http://localhost:9999/echo", + } + modes := []TargetAddrMode{TargetAddrModeHost, TargetAddrModePath, TargetAddrModeFull} + + for _, raw := range rawURLs { + u, err := url.Parse(raw) + if err != nil { + t.Fatalf("url.Parse(%q) failed: %v", raw, err) + } + for _, mode := range modes { + fromRaw := FormatTargetAddr(raw, mode) + fromURL := FormatTargetAddrFromURL(u, mode) + if fromRaw != fromURL { + t.Errorf("FormatTargetAddr(%q, %q) = %q, but FormatTargetAddrFromURL(parsed, %q) = %q (must agree)", + raw, mode, fromRaw, mode, fromURL) + } + } + } +} diff --git a/internal/proxy/middleware_bench_test.go b/internal/proxy/middleware_bench_test.go index 4e257da..1899739 100644 --- a/internal/proxy/middleware_bench_test.go +++ b/internal/proxy/middleware_bench_test.go @@ -38,7 +38,7 @@ func BenchmarkRequestLoggingMiddleware(b *testing.B) { inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) - handler := observability.RequestLoggerMiddleware(slog.Default(), "X-Connect", inner) + handler := observability.RequestLoggerMiddleware(slog.Default(), "X-Connect", observability.TargetAddrModeHost, inner) req := httptest.NewRequest("GET", "/proxy", nil) b.ReportAllocs() @@ -61,7 +61,7 @@ func BenchmarkMiddlewareStack(b *testing.B) { // Stack middlewares as they would be in production // Order: TraceID (outermost) → Logger → PanicRecovery → handler handler := PanicRecoveryMiddleware(inner) - handler = observability.RequestLoggerMiddleware(slog.Default(), "X-Connect", handler) + handler = observability.RequestLoggerMiddleware(slog.Default(), "X-Connect", observability.TargetAddrModeHost, handler) handler = observability.TraceIDMiddleware("Connect-Request-ID", handler) req := httptest.NewRequest("GET", 
"/proxy", nil) @@ -82,7 +82,7 @@ func BenchmarkMiddlewareStack_Parallel(b *testing.B) { w.WriteHeader(http.StatusOK) }) handler := PanicRecoveryMiddleware(inner) - handler = observability.RequestLoggerMiddleware(slog.Default(), "X-Connect", handler) + handler = observability.RequestLoggerMiddleware(slog.Default(), "X-Connect", observability.TargetAddrModeHost, handler) handler = observability.TraceIDMiddleware("Connect-Request-ID", handler) b.ReportAllocs() diff --git a/internal/proxy/server.go b/internal/proxy/server.go index 660d448..53c0208 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -83,6 +83,11 @@ type Config struct { // TracingEnabled controls whether OpenTelemetry tracing middleware is active. TracingEnabled bool + // LogTargetAddrMode controls how the upstream target appears in log + // output (the `target_addr` field). See observability.TargetAddrMode. + // Empty defaults to host-only, the safest behavior. + LogTargetAddrMode observability.TargetAddrMode + // Timeouts ReadTimeout time.Duration WriteTimeout time.Duration @@ -245,6 +250,7 @@ func (s *Server) Handler() http.Handler { router.NewAllowListMiddleware( s.config.AllowList, s.config.HeaderPrefix, + s.config.LogTargetAddrMode, http.HandlerFunc(s.handleProxy), ), ), @@ -478,7 +484,7 @@ func (s *Server) withMiddleware(handler http.Handler) http.Handler { ) } - handler = observability.RequestLoggerMiddleware(slog.Default(), s.config.HeaderPrefix, handler) + handler = observability.RequestLoggerMiddleware(slog.Default(), s.config.HeaderPrefix, s.config.LogTargetAddrMode, handler) handler = observability.TraceIDMiddleware(s.config.TraceHeader, handler) return handler } @@ -510,17 +516,25 @@ func (s *Server) handleProxy(w http.ResponseWriter, r *http.Request) { return } + // Parse the target URL once. 
If parsing fails, target_addr defaults to "" + // (consistent with FormatTargetAddr's behavior for malformed input) so + // the DEBUG breadcrumb still fires before the bad-request response. + targetURL, parseErr := url.Parse(txCtx.TargetURL) + var targetAddr string + if parseErr == nil { + targetAddr = observability.FormatTargetAddrFromURL(targetURL, s.config.LogTargetAddrMode) + } + slog.Debug("transaction context parsed", "trace_id", traceID, "vendor_id", txCtx.VendorID, "marketplace_id", txCtx.MarketplaceID, "product_id", txCtx.ProductID, - "target_host", extractTargetHost(txCtx.TargetURL), + "target_addr", targetAddr, ) - targetURL, err := url.Parse(txCtx.TargetURL) - if err != nil { - s.respondBadRequest(w, traceID, "invalid target URL", err) + if parseErr != nil { + s.respondBadRequest(w, traceID, "invalid target URL", parseErr) return } @@ -530,7 +544,7 @@ func (s *Server) handleProxy(w http.ResponseWriter, r *http.Request) { slog.Warn("insecure target URL rejected", "trace_id", traceID, "target_scheme", targetURL.Scheme, - "target_host", targetURL.Host, + "target_addr", targetAddr, ) http.Error(w, "Bad Request: "+err.Error(), http.StatusBadRequest) return @@ -540,18 +554,18 @@ func (s *Server) handleProxy(w http.ResponseWriter, r *http.Request) { if targetURL.Scheme == "http" { slog.Warn("forwarding to insecure HTTP target - DEVELOPMENT ONLY", "trace_id", traceID, - "target_host", targetURL.Host, + "target_addr", targetAddr, ) } - r, err = s.injectCredentials(r, txCtx, targetURL.Host) + r, err = s.injectCredentials(r, txCtx, targetAddr) if err != nil { - s.handlePluginError(w, traceID, txCtx, targetURL.Host, err) + s.handlePluginError(w, traceID, txCtx, targetAddr, err) return } //nolint:contextcheck // ModifyResponse uses resp.Request.Context() internally - s.forwardRequest(w, r, targetURL, traceID, txCtx) + s.forwardRequest(w, r, targetURL, traceID, txCtx, targetAddr) } // respondBadRequest logs and responds with a 400 Bad Request. 
@@ -572,13 +586,14 @@ func (s *Server) respondBadRequest(w http.ResponseWriter, traceID, msg string, e // RedactingHandler can detect and redact them if they leak into log output // (value-based scanning, Layers 3 & 4). // -// targetHost is the already-parsed host from the target URL, passed from handleProxy -// to avoid re-parsing and to keep the field in DEBUG log output. +// targetAddr is the pre-formatted target address (per LogTargetAddrMode), +// passed from handleProxy to avoid re-formatting and to keep the field +// consistent in DEBUG log output across all sites. // // Returns the (possibly updated) request and any error. The caller MUST use // the returned request for all subsequent operations, because the context may // have been enriched with secret values and injected header keys. -func (s *Server) injectCredentials(r *http.Request, txCtx *sdk.TransactionContext, targetHost string) (*http.Request, error) { +func (s *Server) injectCredentials(r *http.Request, txCtx *sdk.TransactionContext, targetAddr string) (*http.Request, error) { if s.config.Plugin == nil { return r, nil } @@ -618,7 +633,7 @@ func (s *Server) injectCredentials(r *http.Request, txCtx *sdk.TransactionContex "trace_id", txCtx.TraceID, "vendor_id", txCtx.VendorID, "marketplace_id", txCtx.MarketplaceID, - "target_host", targetHost, + "target_addr", targetAddr, "credential_path", "fast", "injected_header_count", len(cred.Headers), "plugin_duration_ms", pluginDuration.Milliseconds(), @@ -637,7 +652,7 @@ func (s *Server) injectCredentials(r *http.Request, txCtx *sdk.TransactionContex "trace_id", txCtx.TraceID, "vendor_id", txCtx.VendorID, "marketplace_id", txCtx.MarketplaceID, - "target_host", targetHost, + "target_addr", targetAddr, "credential_path", "slow", "injected_header_count", injectedCount, "plugin_duration_ms", pluginDuration.Milliseconds(), @@ -701,17 +716,6 @@ func (s *Server) detectSlowPathInjections(r *http.Request, before http.Header) ( return r, len(injectedKeys) } -// 
extractTargetHost parses rawURL and returns only the host (with port if present). -// Used in log output to avoid leaking sensitive path or query information. -// Returns an empty string if the URL is invalid or has no host. -func extractTargetHost(rawURL string) string { - u, err := url.Parse(rawURL) - if err != nil || u.Host == "" { - return "" - } - return u.Host -} - // headerValuesEqual returns true if two header value slices are identical. func headerValuesEqual(a, b []string) bool { if len(a) != len(b) { @@ -735,12 +739,12 @@ func (s *Server) stripContextHeaders(req *http.Request) { } // forwardRequest forwards the request to the target URL via reverse proxy. -func (s *Server) forwardRequest(w http.ResponseWriter, r *http.Request, target *url.URL, traceID string, txCtx *sdk.TransactionContext) { +func (s *Server) forwardRequest(w http.ResponseWriter, r *http.Request, target *url.URL, traceID string, txCtx *sdk.TransactionContext, targetAddr string) { // Record upstream timing for both telemetry metrics and Server-Timing header telTiming := telemetry.TimingFromContext(r.Context()) upstreamStart := time.Now() - proxy := s.createReverseProxy(target, traceID, txCtx, telTiming, upstreamStart) + proxy := s.createReverseProxy(target, traceID, txCtx, telTiming, upstreamStart, targetAddr) proxy.ServeHTTP(w, r) // #nosec G704 -- target validated against allow-list in handleProxy before reaching here } @@ -749,14 +753,14 @@ func (s *Server) forwardRequest(w http.ResponseWriter, r *http.Request, target * const StatusClientClosedRequest = 499 // handlePluginError handles errors from the plugin. 
-func (s *Server) handlePluginError(w http.ResponseWriter, traceID string, txCtx *sdk.TransactionContext, targetHost string, err error) { +func (s *Server) handlePluginError(w http.ResponseWriter, traceID string, txCtx *sdk.TransactionContext, targetAddr string, err error) { if errors.Is(err, context.DeadlineExceeded) { slog.Error("plugin timeout", "trace_id", traceID, "vendor_id", txCtx.VendorID, "marketplace_id", txCtx.MarketplaceID, "product_id", txCtx.ProductID, - "target_host", targetHost, + "target_addr", targetAddr, "error", err, ) http.Error(w, "Gateway Timeout", http.StatusGatewayTimeout) @@ -769,7 +773,7 @@ func (s *Server) handlePluginError(w http.ResponseWriter, traceID string, txCtx "vendor_id", txCtx.VendorID, "marketplace_id", txCtx.MarketplaceID, "product_id", txCtx.ProductID, - "target_host", targetHost, + "target_addr", targetAddr, ) // Write 499 so RequestLoggerMiddleware logs the correct status instead of // the default 200. Do NOT write a body — the client is already gone. @@ -782,7 +786,7 @@ func (s *Server) handlePluginError(w http.ResponseWriter, traceID string, txCtx "vendor_id", txCtx.VendorID, "marketplace_id", txCtx.MarketplaceID, "product_id", txCtx.ProductID, - "target_host", targetHost, + "target_addr", targetAddr, "error", err, ) http.Error(w, "Internal Server Error", http.StatusInternalServerError) @@ -799,7 +803,7 @@ func (s *Server) handlePluginError(w http.ResponseWriter, traceID string, txCtx // Exactly one of ModifyResponse or ErrorHandler fires per request — never both. 
// //nolint:contextcheck // ErrorHandler signature is defined by httputil.ReverseProxy; context is accessed via r.Context() -func (s *Server) createReverseProxy(target *url.URL, traceID string, txCtx *sdk.TransactionContext, telTiming *telemetry.Timing, upstreamStart time.Time) *httputil.ReverseProxy { +func (s *Server) createReverseProxy(target *url.URL, traceID string, txCtx *sdk.TransactionContext, telTiming *telemetry.Timing, upstreamStart time.Time, targetAddr string) *httputil.ReverseProxy { proxy := httputil.NewSingleHostReverseProxy(target) // #nosec G704 -- target validated against allow-list in handleProxy before reaching here // Apply upstream transport with configurable timeouts. @@ -832,7 +836,7 @@ func (s *Server) createReverseProxy(target *url.URL, traceID string, txCtx *sdk. } // Response modification chain: Timing → Plugin → Strip Headers → Error Normalization - proxy.ModifyResponse = s.buildModifyResponse(traceID, txCtx, telTiming, upstreamStart) //nolint:bodyclose // resp.Body is managed by httputil.ReverseProxy + proxy.ModifyResponse = s.buildModifyResponse(traceID, txCtx, telTiming, upstreamStart, targetAddr) //nolint:bodyclose // resp.Body is managed by httputil.ReverseProxy // Handle proxy errors (upstream unreachable, connection refused, etc.) // ErrorHandler fires instead of ModifyResponse when the upstream is unreachable. @@ -856,7 +860,7 @@ func (s *Server) createReverseProxy(target *url.URL, traceID string, txCtx *sdk. "vendor_id", txCtx.VendorID, "marketplace_id", txCtx.MarketplaceID, "product_id", txCtx.ProductID, - "target_host", target.Host, + "target_addr", targetAddr, "error", err, ) http.Error(w, "Bad Gateway", http.StatusBadGateway) @@ -874,7 +878,7 @@ func (s *Server) upstreamTransport() http.RoundTripper { // buildModifyResponse creates the response modification closure that runs // the timing → plugin → strip headers → error normalization chain. 
-func (s *Server) buildModifyResponse(traceID string, txCtx *sdk.TransactionContext, telTiming *telemetry.Timing, upstreamStart time.Time) func(*http.Response) error { +func (s *Server) buildModifyResponse(traceID string, txCtx *sdk.TransactionContext, telTiming *telemetry.Timing, upstreamStart time.Time, targetAddr string) func(*http.Response) error { return func(resp *http.Response) error { // Step 0a: Record upstream duration for telemetry metrics (safe across goroutines) telemetry.RecordUpstreamDuration(resp.Request.Context(), telTiming) @@ -902,7 +906,7 @@ func (s *Server) buildModifyResponse(traceID string, txCtx *sdk.TransactionConte "vendor_id", txCtx.VendorID, "marketplace_id", txCtx.MarketplaceID, "product_id", txCtx.ProductID, - "target_host", resp.Request.URL.Host, + "target_addr", targetAddr, "error", err, ) // Continue with response processing even if plugin fails @@ -919,7 +923,7 @@ func (s *Server) buildModifyResponse(traceID string, txCtx *sdk.TransactionConte security.StripInjectedHeaders(resp.Request.Context(), resp.Header) // Step 3: Core error normalization (safety net - unless plugin opted out) - s.applyErrorNormalization(traceID, txCtx, resp, action) + s.applyErrorNormalization(traceID, txCtx, resp, action, targetAddr) slog.Info("upstream response", "trace_id", traceID, @@ -928,7 +932,7 @@ func (s *Server) buildModifyResponse(traceID string, txCtx *sdk.TransactionConte "vendor_id", txCtx.VendorID, "marketplace_id", txCtx.MarketplaceID, "product_id", txCtx.ProductID, - "target_host", resp.Request.URL.Host, + "target_addr", targetAddr, ) return nil } @@ -937,7 +941,7 @@ func (s *Server) buildModifyResponse(traceID string, txCtx *sdk.TransactionConte // applyErrorNormalization runs Step 3 of the response modification chain. // If the plugin opted out via ResponseAction.SkipErrorNormalization, it logs // the opt-out at DEBUG and skips. Otherwise it runs the core error normalization. 
-func (s *Server) applyErrorNormalization(traceID string, txCtx *sdk.TransactionContext, resp *http.Response, action *sdk.ResponseAction) { +func (s *Server) applyErrorNormalization(traceID string, txCtx *sdk.TransactionContext, resp *http.Response, action *sdk.ResponseAction, targetAddr string) { if action != nil && action.SkipErrorNormalization { slog.Debug("plugin opted out of error normalization", "trace_id", traceID, @@ -950,7 +954,7 @@ func (s *Server) applyErrorNormalization(traceID string, txCtx *sdk.TransactionC "vendor_id", txCtx.VendorID, "marketplace_id", txCtx.MarketplaceID, "product_id", txCtx.ProductID, - "target_host", resp.Request.URL.Host, + "target_addr", targetAddr, "error", err, ) // Continue even if normalization fails - response will be sent as-is diff --git a/internal/proxy/server_helpers_test.go b/internal/proxy/server_helpers_test.go deleted file mode 100644 index 818aba7..0000000 --- a/internal/proxy/server_helpers_test.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2026 CloudBlue LLC -// SPDX-License-Identifier: Apache-2.0 - -package proxy - -import "testing" - -func TestExtractTargetHost(t *testing.T) { - tests := []struct { - name string - input string - want string - }{ - { - name: "returns host only, strips path", - input: "https://api.vendor.com/v1/users/123", - want: "api.vendor.com", - }, - { - name: "returns host only, strips query string", - input: "https://api.vendor.com/v1?api_key=secret", - want: "api.vendor.com", - }, - { - name: "returns host only, strips userinfo", - input: "https://user:pass@api.vendor.com/v1", - want: "api.vendor.com", - }, - { - name: "preserves port", - input: "https://api.vendor.com:8443/v1", - want: "api.vendor.com:8443", - }, - { - name: "strips all sensitive parts together", - input: "https://user:pass@api.vendor.com/v1/users/alice@example.com?token=abc#frag", - want: "api.vendor.com", - }, - { - name: "invalid URL returns empty", - input: "://invalid", - want: "", - }, - { - name: "empty URL 
returns empty", - input: "", - want: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := extractTargetHost(tt.input) - if got != tt.want { - t.Errorf("extractTargetHost(%q) = %q, want %q", tt.input, got, tt.want) - } - }) - } -} diff --git a/internal/proxy/server_test.go b/internal/proxy/server_test.go index db34f1c..34f011b 100644 --- a/internal/proxy/server_test.go +++ b/internal/proxy/server_test.go @@ -375,7 +375,7 @@ func TestMiddlewareStack_PanicLogsCorrectStatus(t *testing.T) { // Apply middleware in production order: TraceID → Logger → PanicRecovery → handler handler := proxy.PanicRecoveryMiddleware(panicHandler) - handler = observability.RequestLoggerMiddleware(slog.Default(), "X-Connect", handler) + handler = observability.RequestLoggerMiddleware(slog.Default(), "X-Connect", observability.TargetAddrModeHost, handler) handler = observability.TraceIDMiddleware("Connect-Request-ID", handler) req := httptest.NewRequest(http.MethodPost, "/test/panic", nil) @@ -420,7 +420,7 @@ func TestMiddlewareStack_NormalRequestLogsCorrectStatus(t *testing.T) { // Apply middleware in production order: TraceID → Logger → PanicRecovery → handler wrapped := proxy.PanicRecoveryMiddleware(handler) - wrapped = observability.RequestLoggerMiddleware(slog.Default(), "X-Connect", wrapped) + wrapped = observability.RequestLoggerMiddleware(slog.Default(), "X-Connect", observability.TargetAddrModeHost, wrapped) wrapped = observability.TraceIDMiddleware("Connect-Request-ID", wrapped) req := httptest.NewRequest(http.MethodPost, "/resource", nil) diff --git a/internal/proxy/target_addr_integration_test.go b/internal/proxy/target_addr_integration_test.go new file mode 100644 index 0000000..31fcc31 --- /dev/null +++ b/internal/proxy/target_addr_integration_test.go @@ -0,0 +1,271 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package proxy_test + +import ( + "bytes" + "context" + "encoding/json" + "log/slog" + "net/http" + 
"net/http/httptest" + "strings" + "testing" + "time" + + "github.com/cloudblue/chaperone/internal/observability" + "github.com/cloudblue/chaperone/sdk" +) + +// ============================================================================= +// Integration tests for the configurable target_addr field (LITE-34062). +// +// These tests verify the three log_target_addr modes end-to-end through the +// full middleware stack (allow-list, request logger, proxy handlers) and +// guard the cross-site consistency invariant — every log line that emits +// target_addr must produce identical content per request. +// ============================================================================= + +// targetAddrTestSetup runs a request through the full proxy stack with the +// given LogTargetAddrMode and returns the captured JSON log lines. +func targetAddrTestSetup(t *testing.T, mode observability.TargetAddrMode, target, requestURL string) []map[string]any { + t.Helper() + + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) + original := slog.Default() + slog.SetDefault(logger) + t.Cleanup(func() { slog.SetDefault(original) }) + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(backend.Close) + + plugin := &mockPlugin{ + getCredentialsFn: func(_ context.Context, _ sdk.TransactionContext, _ *http.Request) (*sdk.Credential, error) { + return &sdk.Credential{ + Headers: map[string]string{"Authorization": "Bearer test"}, + ExpiresAt: time.Now().Add(time.Hour), + }, nil + }, + } + + cfg := testConfig() + cfg.Plugin = plugin + cfg.LogTargetAddrMode = mode + srv := mustNewServerForTarget(t, cfg, target) + handler := srv.Handler() + + req := httptest.NewRequest(http.MethodGet, "/proxy", nil) + req.Header.Set("X-Connect-Target-URL", requestURL) + req.Header.Set("X-Connect-Vendor-ID", "VA-test") + 
req.Header.Set("X-Connect-Marketplace-ID", "MP-US") + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d. body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + return parseJSONLogLines(t, buf.Bytes()) +} + +func parseJSONLogLines(t *testing.T, b []byte) []map[string]any { + t.Helper() + var out []map[string]any + for _, line := range bytes.Split(b, []byte{'\n'}) { + if len(bytes.TrimSpace(line)) == 0 { + continue + } + var m map[string]any + if err := json.Unmarshal(line, &m); err != nil { + t.Fatalf("invalid JSON log line %q: %v", line, err) + } + out = append(out, m) + } + return out +} + +// targetAddrFromLogs collects all distinct non-empty target_addr values. +func targetAddrFromLogs(lines []map[string]any) []string { + var values []string + seen := map[string]struct{}{} + for _, line := range lines { + v, ok := line["target_addr"].(string) + if !ok || v == "" { + continue + } + if _, dup := seen[v]; !dup { + seen[v] = struct{}{} + values = append(values, v) + } + } + return values +} + +// joinedLog returns the entire buffer as a single string for substring asserts. +func joinedLog(lines []map[string]any) string { + parts := make([]string, 0, len(lines)) + for _, line := range lines { + b, _ := json.Marshal(line) + parts = append(parts, string(b)) + } + return strings.Join(parts, "\n") +} + +// TestLogTargetAddr_HostMode_OnlyAuthority verifies the default mode emits +// only the authority and never lets path/query/userinfo leak into any log line. 
+func TestLogTargetAddr_HostMode_OnlyAuthority(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + requestURL := backend.URL + "/v1/users/alice@example.com?api_key=SHOULD_NOT_APPEAR&token=SECRET" + + lines := targetAddrTestSetup(t, observability.TargetAddrModeHost, backend.URL, requestURL) + + addrs := targetAddrFromLogs(lines) + if len(addrs) == 0 { + t.Fatalf("no target_addr field found in any log line, got %d lines", len(lines)) + } + for _, v := range addrs { + if strings.Contains(v, "://") { + t.Errorf("target_addr in host mode must not contain scheme, got %q", v) + } + if strings.Contains(v, "/") { + t.Errorf("target_addr in host mode must not contain path, got %q", v) + } + if strings.Contains(v, "?") { + t.Errorf("target_addr in host mode must not contain query, got %q", v) + } + } + + full := joinedLog(lines) + for _, leak := range []string{"SHOULD_NOT_APPEAR", "SECRET", "alice@example.com", "/v1/users", "api_key", "token="} { + if strings.Contains(full, leak) { + t.Errorf("host mode log must not contain %q, got: %s", leak, full) + } + } +} + +// TestLogTargetAddr_PathMode_StripsQuery verifies that path mode emits +// scheme+host+path while still stripping the query string. 
+func TestLogTargetAddr_PathMode_StripsQuery(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + requestURL := backend.URL + "/v1/users?api_key=SHOULD_NOT_APPEAR&token=SECRET" + + lines := targetAddrTestSetup(t, observability.TargetAddrModePath, backend.URL, requestURL) + + addrs := targetAddrFromLogs(lines) + if len(addrs) == 0 { + t.Fatalf("no target_addr field found, got %d lines", len(lines)) + } + for _, v := range addrs { + if !strings.Contains(v, "://") { + t.Errorf("target_addr in path mode must contain scheme, got %q", v) + } + if !strings.Contains(v, "/v1/users") { + t.Errorf("target_addr in path mode must contain path /v1/users, got %q", v) + } + if strings.Contains(v, "?") { + t.Errorf("target_addr in path mode must not contain query, got %q", v) + } + } + + full := joinedLog(lines) + for _, leak := range []string{"SHOULD_NOT_APPEAR", "SECRET", "api_key", "token="} { + if strings.Contains(full, leak) { + t.Errorf("path mode log must not contain %q (query value), got: %s", leak, full) + } + } +} + +// TestLogTargetAddr_FullMode_KeepsQueryStripsUserinfo verifies that full +// mode emits the query string and ALWAYS strips userinfo. +func TestLogTargetAddr_FullMode_KeepsQueryStripsUserinfo(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + // Build a target URL with userinfo embedded. 
+ requestURL := strings.Replace(backend.URL, "://", "://USER:PASS_NEVER_IN_LOGS@", 1) + "/v1/users?key=val&token=APPEARS_HERE" + + lines := targetAddrTestSetup(t, observability.TargetAddrModeFull, backend.URL, requestURL) + + addrs := targetAddrFromLogs(lines) + if len(addrs) == 0 { + t.Fatalf("no target_addr field found, got %d lines", len(lines)) + } + for _, v := range addrs { + if !strings.Contains(v, "?key=val") { + t.Errorf("target_addr in full mode must contain query, got %q", v) + } + if strings.Contains(v, "USER") || strings.Contains(v, "PASS_NEVER_IN_LOGS") { + t.Errorf("target_addr in full mode must always strip userinfo, got %q", v) + } + } + + full := joinedLog(lines) + if !strings.Contains(full, "APPEARS_HERE") { + t.Errorf("full mode must include query value, got: %s", full) + } + for _, leak := range []string{"USER", "PASS_NEVER_IN_LOGS"} { + if strings.Contains(full, leak) { + t.Errorf("userinfo must NEVER appear in any log mode, found %q in: %s", leak, full) + } + } +} + +// TestLogTargetAddr_AllSitesAgree guards the consistency invariant: for a +// single request, every log line that reports target_addr must contain the +// same string. A future regression that re-formats the field at one site +// (e.g. a new logger using a stale helper) is caught here. 
+func TestLogTargetAddr_AllSitesAgree(t *testing.T) { + for _, mode := range []observability.TargetAddrMode{ + observability.TargetAddrModeHost, + observability.TargetAddrModePath, + observability.TargetAddrModeFull, + } { + t.Run(string(mode), func(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + lines := targetAddrTestSetup(t, mode, backend.URL, backend.URL+"/v1/users?key=val") + + addrs := targetAddrFromLogs(lines) + if len(addrs) < 1 { + t.Fatalf("expected at least one target_addr value, got 0") + } + if len(addrs) > 1 { + t.Errorf("target_addr must be byte-identical across all log sites for a single request, got %d distinct values: %v", + len(addrs), addrs) + } + + // Sanity: at least the request_completed line and the upstream_response + // line must be present. + expectedMsgs := []string{"request completed", "upstream response"} + for _, msg := range expectedMsgs { + found := false + for _, line := range lines { + if line["msg"] == msg { + if v, _ := line["target_addr"].(string); v == "" { + t.Errorf("log line %q missing target_addr field", msg) + } + found = true + } + } + if !found { + t.Errorf("expected a log line with msg=%q, none found", msg) + } + } + }) + } +} diff --git a/internal/router/middleware.go b/internal/router/middleware.go index c07ad2c..1f25f87 100644 --- a/internal/router/middleware.go +++ b/internal/router/middleware.go @@ -8,7 +8,6 @@ import ( "errors" "log/slog" "net/http" - "net/url" "github.com/cloudblue/chaperone/internal/observability" ) @@ -18,6 +17,7 @@ import ( type AllowListMiddleware struct { validator *AllowListValidator headerPrefix string + addrMode observability.TargetAddrMode next http.Handler } @@ -26,11 +26,14 @@ type AllowListMiddleware struct { // Parameters: // - allowList: The host-to-paths mapping from configuration // - headerPrefix: The prefix for context headers (e.g., "X-Connect") +// 
- addrMode: how to format the target_addr field in WARN/DEBUG logs +// (see observability.TargetAddrMode). Empty defaults to host-only. // - next: The next handler in the chain -func NewAllowListMiddleware(allowList map[string][]string, headerPrefix string, next http.Handler) *AllowListMiddleware { +func NewAllowListMiddleware(allowList map[string][]string, headerPrefix string, addrMode observability.TargetAddrMode, next http.Handler) *AllowListMiddleware { return &AllowListMiddleware{ validator: NewAllowListValidator(allowList), headerPrefix: headerPrefix, + addrMode: addrMode, next: next, } } @@ -57,7 +60,7 @@ func (m *AllowListMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) if err := m.validator.Validate(targetURL); err != nil { slog.Warn("allow list validation failed", "trace_id", observability.TraceIDFromContext(r.Context()), - "target_host", extractHostFromURL(targetURL), + "target_addr", observability.FormatTargetAddr(targetURL, m.addrMode), "error", err.Error(), "remote_addr", r.RemoteAddr, ) @@ -79,23 +82,13 @@ func (m *AllowListMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) slog.Debug("allow list validation passed", "trace_id", observability.TraceIDFromContext(r.Context()), - "target_host", extractHostFromURL(targetURL), + "target_addr", observability.FormatTargetAddr(targetURL, m.addrMode), ) // Validation passed, continue to next handler m.next.ServeHTTP(w, r) } -// extractHostFromURL parses a URL string and returns only the host portion. -// Returns an empty string if the URL is invalid or has no host. -func extractHostFromURL(rawURL string) string { - u, err := url.Parse(rawURL) - if err != nil || u.Host == "" { - return "" - } - return u.Host -} - // errorResponse is the JSON structure for error responses. 
type errorResponse struct { Error string `json:"error"` diff --git a/internal/router/middleware_test.go b/internal/router/middleware_test.go index 9250db1..2a96f78 100644 --- a/internal/router/middleware_test.go +++ b/internal/router/middleware_test.go @@ -25,7 +25,7 @@ func TestAllowListMiddleware_ValidRequest(t *testing.T) { _, _ = w.Write([]byte("OK")) }) - middleware := NewAllowListMiddleware(allowList, "X-Connect", nextHandler) + middleware := NewAllowListMiddleware(allowList, "X-Connect", observability.TargetAddrModeHost, nextHandler) req := httptest.NewRequest(http.MethodGet, "/proxy", nil) req.Header.Set("X-Connect-Target-URL", "https://api.example.com/v1/customers") @@ -52,7 +52,7 @@ func TestAllowListMiddleware_BlockedHost(t *testing.T) { t.Error("next handler should not be called for blocked host") }) - middleware := NewAllowListMiddleware(allowList, "X-Connect", nextHandler) + middleware := NewAllowListMiddleware(allowList, "X-Connect", observability.TargetAddrModeHost, nextHandler) req := httptest.NewRequest(http.MethodGet, "/proxy", nil) req.Header.Set("X-Connect-Target-URL", "https://evil.com/data") @@ -79,7 +79,7 @@ func TestAllowListMiddleware_BlockedPath(t *testing.T) { t.Error("next handler should not be called for blocked path") }) - middleware := NewAllowListMiddleware(allowList, "X-Connect", nextHandler) + middleware := NewAllowListMiddleware(allowList, "X-Connect", observability.TargetAddrModeHost, nextHandler) req := httptest.NewRequest(http.MethodGet, "/proxy", nil) req.Header.Set("X-Connect-Target-URL", "https://api.example.com/admin/users") @@ -106,7 +106,7 @@ func TestAllowListMiddleware_MissingTargetURL(t *testing.T) { t.Error("next handler should not be called when target URL is missing") }) - middleware := NewAllowListMiddleware(allowList, "X-Connect", nextHandler) + middleware := NewAllowListMiddleware(allowList, "X-Connect", observability.TargetAddrModeHost, nextHandler) req := httptest.NewRequest(http.MethodGet, "/proxy", nil) // 
Not setting X-Connect-Target-URL header @@ -128,7 +128,7 @@ func TestAllowListMiddleware_CustomHeaderPrefix(t *testing.T) { w.WriteHeader(http.StatusOK) }) - middleware := NewAllowListMiddleware(allowList, "X-Custom", nextHandler) + middleware := NewAllowListMiddleware(allowList, "X-Custom", observability.TargetAddrModeHost, nextHandler) req := httptest.NewRequest(http.MethodGet, "/proxy", nil) req.Header.Set("X-Custom-Target-URL", "https://api.example.com/test") @@ -146,7 +146,7 @@ func TestAllowListMiddleware_EmptyAllowList(t *testing.T) { t.Error("next handler should not be called for empty allow list") }) - middleware := NewAllowListMiddleware(nil, "X-Connect", nextHandler) + middleware := NewAllowListMiddleware(nil, "X-Connect", observability.TargetAddrModeHost, nextHandler) req := httptest.NewRequest(http.MethodGet, "/proxy", nil) req.Header.Set("X-Connect-Target-URL", "https://api.example.com/test") @@ -168,7 +168,7 @@ func TestAllowListMiddleware_ResponseBody(t *testing.T) { t.Error("next handler should not be called for blocked host") }) - middleware := NewAllowListMiddleware(allowList, "X-Connect", nextHandler) + middleware := NewAllowListMiddleware(allowList, "X-Connect", observability.TargetAddrModeHost, nextHandler) req := httptest.NewRequest(http.MethodGet, "/proxy", nil) req.Header.Set("X-Connect-Target-URL", "https://evil.com/data") @@ -204,7 +204,7 @@ func TestAllowListMiddleware_InvalidTargetURL(t *testing.T) { t.Error("next handler should not be called for invalid URL") }) - middleware := NewAllowListMiddleware(allowList, "X-Connect", nextHandler) + middleware := NewAllowListMiddleware(allowList, "X-Connect", observability.TargetAddrModeHost, nextHandler) tests := []struct { name string @@ -252,7 +252,7 @@ func TestAllowListMiddleware_DoesNotLeakURLDetails(t *testing.T) { t.Error("next handler should not be called") }) - middleware := NewAllowListMiddleware(allowList, "X-Connect", nextHandler) + middleware := NewAllowListMiddleware(allowList, 
"X-Connect", observability.TargetAddrModeHost, nextHandler) // Test with sensitive-looking URL req := httptest.NewRequest(http.MethodGet, "/proxy", nil) @@ -298,7 +298,7 @@ func TestAllowListMiddleware_AllMethods(t *testing.T) { w.WriteHeader(http.StatusOK) }) - middleware := NewAllowListMiddleware(allowList, "X-Connect", nextHandler) + middleware := NewAllowListMiddleware(allowList, "X-Connect", observability.TargetAddrModeHost, nextHandler) req := httptest.NewRequest(method, "/proxy", nil) req.Header.Set("X-Connect-Target-URL", "https://api.example.com/test") @@ -327,7 +327,7 @@ func TestAllowListMiddleware_ValidationPassed_DebugLog(t *testing.T) { nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) - middleware := NewAllowListMiddleware(allowList, "X-Connect", nextHandler) + middleware := NewAllowListMiddleware(allowList, "X-Connect", observability.TargetAddrModeHost, nextHandler) req := httptest.NewRequest(http.MethodGet, "/proxy", nil) req = req.WithContext(observability.WithTraceID(req.Context(), "trace-debug-123")) @@ -347,8 +347,8 @@ func TestAllowListMiddleware_ValidationPassed_DebugLog(t *testing.T) { if !strings.Contains(logOutput, `"trace_id":"trace-debug-123"`) { t.Errorf("expected trace_id in log, got: %s", logOutput) } - if !strings.Contains(logOutput, `"target_host":"api.example.com"`) { - t.Errorf("expected target_host in log, got: %s", logOutput) + if !strings.Contains(logOutput, `"target_addr":"api.example.com"`) { + t.Errorf("expected target_addr in log, got: %s", logOutput) } } @@ -363,7 +363,7 @@ func TestAllowListMiddleware_ValidationFailed_HasTraceID(t *testing.T) { nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Error("next handler should not be called") }) - middleware := NewAllowListMiddleware(allowList, "X-Connect", nextHandler) + middleware := NewAllowListMiddleware(allowList, "X-Connect", observability.TargetAddrModeHost, nextHandler) req := 
httptest.NewRequest(http.MethodGet, "/proxy", nil) req = req.WithContext(observability.WithTraceID(req.Context(), "trace-fail-456")) @@ -380,8 +380,8 @@ func TestAllowListMiddleware_ValidationFailed_HasTraceID(t *testing.T) { if !strings.Contains(logOutput, `"trace_id":"trace-fail-456"`) { t.Errorf("expected trace_id in failure log, got: %s", logOutput) } - if !strings.Contains(logOutput, `"target_host":"evil.com"`) { - t.Errorf("expected target_host in failure log, got: %s", logOutput) + if !strings.Contains(logOutput, `"target_addr":"evil.com"`) { + t.Errorf("expected target_addr in failure log, got: %s", logOutput) } } @@ -401,7 +401,7 @@ func TestAllowListMiddleware_EmptyAllowListDeniesAll(t *testing.T) { t.Error("next handler should not be called") }) - middleware := NewAllowListMiddleware(tc.allowList, "X-Connect", nextHandler) + middleware := NewAllowListMiddleware(tc.allowList, "X-Connect", observability.TargetAddrModeHost, nextHandler) req := httptest.NewRequest(http.MethodGet, "/proxy", nil) req.Header.Set("X-Connect-Target-URL", "https://api.example.com/test")