From 460960b00cc800a0f6c72f0cf92546c3cb6e892e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 02:12:41 +0000 Subject: [PATCH 1/3] Initial plan From c7251952139edca5708ef66f9fd0ede5b8b6b4c8 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 02:19:13 +0000 Subject: [PATCH 2/3] Add OAuth2 client_credentials support to step.http_call Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> --- cmd/wfctl/type_registry.go | 2 +- module/pipeline_step_http_call.go | 298 +++++++++++++++-- module/pipeline_step_http_call_test.go | 436 +++++++++++++++++++++++++ 3 files changed, 699 insertions(+), 37 deletions(-) create mode 100644 module/pipeline_step_http_call_test.go diff --git a/cmd/wfctl/type_registry.go b/cmd/wfctl/type_registry.go index d9bb8f12..f16754ed 100644 --- a/cmd/wfctl/type_registry.go +++ b/cmd/wfctl/type_registry.go @@ -548,7 +548,7 @@ func KnownStepTypes() map[string]StepTypeInfo { "step.http_call": { Type: "step.http_call", Plugin: "pipelinesteps", - ConfigKeys: []string{"url", "method", "headers", "body", "timeout"}, + ConfigKeys: []string{"url", "method", "headers", "body", "timeout", "auth"}, }, "step.request_parse": { Type: "step.request_parse", diff --git a/module/pipeline_step_http_call.go b/module/pipeline_step_http_call.go index cb6e5688..429bd4f9 100644 --- a/module/pipeline_step_http_call.go +++ b/module/pipeline_step_http_call.go @@ -7,27 +7,74 @@ import ( "fmt" "io" "net/http" + "net/url" + "strings" + "sync" "time" "github.com/CrisisTextLine/modular" ) +// oauthConfig holds OAuth2 client_credentials configuration. +type oauthConfig struct { + tokenURL string + clientID string + clientSecret string + scopes []string +} + +// tokenCache holds a cached OAuth2 access token and its expiry. +type tokenCache struct { + mu sync.Mutex + accessToken string + expiry time.Time +} + +// get returns the cached token if it is still valid, or empty string. +func (c *tokenCache) get() string { + c.mu.Lock() + defer c.mu.Unlock() + if c.accessToken != "" && time.Now().Before(c.expiry) { + return c.accessToken + } + return "" +} + +// set stores a token with the given TTL. +func (c *tokenCache) set(token string, ttl time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + c.accessToken = token + c.expiry = time.Now().Add(ttl) +} + +// invalidate clears the cached token. +func (c *tokenCache) invalidate() { + c.mu.Lock() + defer c.mu.Unlock() + c.accessToken = "" + c.expiry = time.Time{} +} + // HTTPCallStep makes an HTTP request as a pipeline step. type HTTPCallStep struct { - name string - url string - method string - headers map[string]string - body map[string]any - timeout time.Duration - tmpl *TemplateEngine + name string + url string + method string + headers map[string]string + body map[string]any + timeout time.Duration + tmpl *TemplateEngine + auth *oauthConfig + tokenCache *tokenCache + httpClient *http.Client } // NewHTTPCallStepFactory returns a StepFactory that creates HTTPCallStep instances. func NewHTTPCallStepFactory() StepFactory { return func(name string, config map[string]any, _ modular.Application) (PipelineStep, error) { - url, _ := config["url"].(string) - if url == "" { + rawURL, _ := config["url"].(string) + if rawURL == "" { return nil, fmt.Errorf("http_call step %q: 'url' is required", name) } @@ -37,11 +84,12 @@ func NewHTTPCallStepFactory() StepFactory { } step := &HTTPCallStep{ - name: name, - url: url, - method: method, - timeout: 30 * time.Second, - tmpl: NewTemplateEngine(), + name: name, + url: rawURL, + method: method, + timeout: 30 * time.Second, + tmpl: NewTemplateEngine(), + httpClient: http.DefaultClient, } if headers, ok := config["headers"].(map[string]any); ok { @@ -63,6 +111,46 @@ func NewHTTPCallStepFactory() StepFactory { } } + if authCfg, ok := config["auth"].(map[string]any); ok { + authType, _ := authCfg["type"].(string) + if authType == "oauth2_client_credentials" { + tokenURL, _ := authCfg["token_url"].(string) + if tokenURL == "" { + return nil, fmt.Errorf("http_call step %q: auth.token_url is required for oauth2_client_credentials", name) + } + clientID, _ := authCfg["client_id"].(string) + if clientID == "" { + return nil, fmt.Errorf("http_call step %q: auth.client_id is required for oauth2_client_credentials", name) + } + clientSecret, _ := authCfg["client_secret"].(string) + if clientSecret == "" { + return nil, fmt.Errorf("http_call step %q: auth.client_secret is required for oauth2_client_credentials", name) + } + + var scopes []string + if raw, ok := authCfg["scopes"]; ok { + switch v := raw.(type) { + case []string: + scopes = v + case []any: + for _, s := range v { + if str, ok := s.(string); ok { + scopes = append(scopes, str) + } + } + } + } + + step.auth = &oauthConfig{ + tokenURL: tokenURL, + clientID: clientID, + clientSecret: clientSecret, + scopes: scopes, + } + step.tokenCache = &tokenCache{} + } + } + return step, nil } } @@ -70,18 +158,74 @@ func NewHTTPCallStepFactory() StepFactory { // Name returns the step name. func (s *HTTPCallStep) Name() string { return s.name } -// Execute performs the HTTP request and returns the response. -func (s *HTTPCallStep) Execute(ctx context.Context, pc *PipelineContext) (*StepResult, error) { - ctx, cancel := context.WithTimeout(ctx, s.timeout) - defer cancel() +// fetchToken obtains a new OAuth2 access token using client_credentials grant. +func (s *HTTPCallStep) fetchToken(ctx context.Context) (string, error) { + params := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {s.auth.clientID}, + "client_secret": {s.auth.clientSecret}, + } + if len(s.auth.scopes) > 0 { + params.Set("scope", strings.Join(s.auth.scopes, " ")) + } - // Resolve URL template - resolvedURL, err := s.tmpl.Resolve(s.url, pc) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.auth.tokenURL, + strings.NewReader(params.Encode())) if err != nil { - return nil, fmt.Errorf("http_call step %q: failed to resolve url: %w", s.name, err) + return "", fmt.Errorf("http_call step %q: failed to create token request: %w", s.name, err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := s.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("http_call step %q: token request failed: %w", s.name, err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("http_call step %q: failed to read token response: %w", s.name, err) } - var bodyReader io.Reader + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("http_call step %q: token endpoint returned HTTP %d: %s", s.name, resp.StatusCode, string(body)) + } + + var tokenResp struct { + AccessToken string `json:"access_token"` //nolint:gosec // G117: parsing OAuth2 token response, not a secret exposure + ExpiresIn float64 `json:"expires_in"` + TokenType string `json:"token_type"` + } + if err := json.Unmarshal(body, &tokenResp); err != nil { + return "", fmt.Errorf("http_call step %q: failed to parse token response: %w", s.name, err) + } + if tokenResp.AccessToken == "" { + return "", fmt.Errorf("http_call step %q: token response missing access_token", s.name) + } + + ttl := time.Duration(tokenResp.ExpiresIn) * time.Second + if ttl <= 0 { + ttl = 3600 * time.Second + } + // Subtract a small buffer to avoid using a token that is about to expire + if ttl > 10*time.Second { + ttl -= 10 * time.Second + } + s.tokenCache.set(tokenResp.AccessToken, ttl) + + return tokenResp.AccessToken, nil +} + +// getToken returns a valid OAuth2 token, fetching one if the cache is empty or expired. +func (s *HTTPCallStep) getToken(ctx context.Context) (string, error) { + if token := s.tokenCache.get(); token != "" { + return token, nil + } + return s.fetchToken(ctx) +} + +// buildBodyReader constructs the request body reader from the step configuration. +func (s *HTTPCallStep) buildBodyReader(pc *PipelineContext) (io.Reader, error) { if s.body != nil { resolvedBody, resolveErr := s.tmpl.ResolveMap(s.body, pc) if resolveErr != nil { @@ -91,15 +235,20 @@ func (s *HTTPCallStep) Execute(ctx context.Context, pc *PipelineContext) (*StepR if marshalErr != nil { return nil, fmt.Errorf("http_call step %q: failed to marshal body: %w", s.name, marshalErr) } - bodyReader = bytes.NewReader(data) - } else if s.method != "GET" && s.method != "HEAD" { + return bytes.NewReader(data), nil + } + if s.method != "GET" && s.method != "HEAD" { data, marshalErr := json.Marshal(pc.Current) if marshalErr != nil { return nil, fmt.Errorf("http_call step %q: failed to marshal current data: %w", s.name, marshalErr) } - bodyReader = bytes.NewReader(data) + return bytes.NewReader(data), nil } + return nil, nil +} +// buildRequest constructs the HTTP request with resolved headers and optional bearer token. +func (s *HTTPCallStep) buildRequest(ctx context.Context, resolvedURL string, bodyReader io.Reader, pc *PipelineContext, bearerToken string) (*http.Request, error) { req, err := http.NewRequestWithContext(ctx, s.method, resolvedURL, bodyReader) if err != nil { return nil, fmt.Errorf("http_call step %q: failed to create request: %w", s.name, err) @@ -116,19 +265,15 @@ func (s *HTTPCallStep) Execute(ctx context.Context, pc *PipelineContext) (*StepR req.Header.Set(k, resolved) } } - - resp, err := http.DefaultClient.Do(req) //nolint:gosec // G704: SSRF via taint analysis - if err != nil { - return nil, fmt.Errorf("http_call step %q: request failed: %w", s.name, err) + if bearerToken != "" { + req.Header.Set("Authorization", "Bearer "+bearerToken) } - defer resp.Body.Close() - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("http_call step %q: failed to read response: %w", s.name, err) - } + return req, nil +} - // Build response headers map +// parseResponse converts an HTTP response into a StepResult output map. +func parseHTTPResponse(resp *http.Response, respBody []byte) map[string]any { respHeaders := make(map[string]any, len(resp.Header)) for k, v := range resp.Header { if len(v) == 1 { @@ -148,7 +293,6 @@ func (s *HTTPCallStep) Execute(ctx context.Context, pc *PipelineContext) (*StepR "headers": respHeaders, } - // Try to parse response as JSON var jsonResp any if json.Unmarshal(respBody, &jsonResp) == nil { output["body"] = jsonResp @@ -156,6 +300,88 @@ func (s *HTTPCallStep) Execute(ctx context.Context, pc *PipelineContext) (*StepR output["body"] = string(respBody) } + return output +} + +// Execute performs the HTTP request and returns the response. +func (s *HTTPCallStep) Execute(ctx context.Context, pc *PipelineContext) (*StepResult, error) { + ctx, cancel := context.WithTimeout(ctx, s.timeout) + defer cancel() + + // Resolve URL template + resolvedURL, err := s.tmpl.Resolve(s.url, pc) + if err != nil { + return nil, fmt.Errorf("http_call step %q: failed to resolve url: %w", s.name, err) + } + + bodyReader, err := s.buildBodyReader(pc) + if err != nil { + return nil, err + } + + // Obtain OAuth2 bearer token if auth is configured + var bearerToken string + if s.auth != nil { + bearerToken, err = s.getToken(ctx) + if err != nil { + return nil, err + } + } + + req, err := s.buildRequest(ctx, resolvedURL, bodyReader, pc, bearerToken) + if err != nil { + return nil, err + } + + resp, err := s.httpClient.Do(req) //nolint:gosec // G107: URL is user-configured + if err != nil { + return nil, fmt.Errorf("http_call step %q: request failed: %w", s.name, err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("http_call step %q: failed to read response: %w", s.name, err) + } + + // On 401, invalidate token cache and retry once with a fresh token + if resp.StatusCode == http.StatusUnauthorized && s.auth != nil { + s.tokenCache.invalidate() + + newToken, tokenErr := s.fetchToken(ctx) + if tokenErr != nil { + return nil, tokenErr + } + + retryBody, buildErr := s.buildBodyReader(pc) + if buildErr != nil { + return nil, buildErr + } + retryReq, buildErr := s.buildRequest(ctx, resolvedURL, retryBody, pc, newToken) + if buildErr != nil { + return nil, buildErr + } + + retryResp, doErr := s.httpClient.Do(retryReq) //nolint:gosec // G107: URL is user-configured + if doErr != nil { + return nil, fmt.Errorf("http_call step %q: retry request failed: %w", s.name, doErr) + } + defer retryResp.Body.Close() + + respBody, err = io.ReadAll(retryResp.Body) + if err != nil { + return nil, fmt.Errorf("http_call step %q: failed to read retry response: %w", s.name, err) + } + + output := parseHTTPResponse(retryResp, respBody) + if retryResp.StatusCode >= 400 { + return nil, fmt.Errorf("http_call step %q: HTTP %d: %s", s.name, retryResp.StatusCode, string(respBody)) + } + return &StepResult{Output: output}, nil + } + + output := parseHTTPResponse(resp, respBody) + if resp.StatusCode >= 400 { return nil, fmt.Errorf("http_call step %q: HTTP %d: %s", s.name, resp.StatusCode, string(respBody)) } diff --git a/module/pipeline_step_http_call_test.go b/module/pipeline_step_http_call_test.go new file mode 100644 index 00000000..651efb92 --- /dev/null +++ b/module/pipeline_step_http_call_test.go @@ -0,0 +1,436 @@ +package module + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" +) + +func TestHTTPCallStep_BasicGET(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"hello":"world"}`)) + })) + defer srv.Close() + + factory := NewHTTPCallStepFactory() + step, err := factory("get-test", map[string]any{ + "url": srv.URL + "/resource", + "method": "GET", + }, nil) + if err != nil { + t.Fatalf("factory error: %v", err) + } + // inject test client + step.(*HTTPCallStep).httpClient = srv.Client() + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + if result.Output["status_code"] != http.StatusOK { + t.Errorf("expected status_code 200, got %v", result.Output["status_code"]) + } + body, ok := result.Output["body"].(map[string]any) + if !ok { + t.Fatalf("expected JSON body, got %T", result.Output["body"]) + } + if body["hello"] != "world" { + t.Errorf("expected hello=world, got %v", body["hello"]) + } +} + +func TestHTTPCallStep_MissingURL(t *testing.T) { + factory := NewHTTPCallStepFactory() + _, err := factory("no-url", map[string]any{}, nil) + if err == nil { + t.Fatal("expected error for missing url") + } + if !strings.Contains(err.Error(), "'url' is required") { + t.Errorf("unexpected error: %v", err) + } +} + +func TestHTTPCallStep_ErrorResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`bad request`)) + })) + defer srv.Close() + + factory := NewHTTPCallStepFactory() + step, err := factory("err-test", map[string]any{"url": srv.URL}, nil) + if err != nil { + t.Fatalf("factory error: %v", err) + } + step.(*HTTPCallStep).httpClient = srv.Client() + + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + if err == nil { + t.Fatal("expected error for 400 response") + } + if !strings.Contains(err.Error(), "HTTP 400") { + t.Errorf("unexpected error: %v", err) + } +} + +// TestHTTPCallStep_OAuth2_FetchesToken verifies that a bearer token is obtained and sent. +func TestHTTPCallStep_OAuth2_FetchesToken(t *testing.T) { + var tokenRequests int32 + + tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&tokenRequests, 1) + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + _ = r.ParseForm() + if r.FormValue("grant_type") != "client_credentials" { + t.Errorf("expected grant_type=client_credentials, got %q", r.FormValue("grant_type")) + } + if r.FormValue("client_id") != "my-client" { + t.Errorf("expected client_id=my-client, got %q", r.FormValue("client_id")) + } + if r.FormValue("client_secret") != "my-secret" { + t.Errorf("expected client_secret=my-secret, got %q", r.FormValue("client_secret")) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-access-token", + "expires_in": 3600, + "token_type": "Bearer", + }) + })) + defer tokenSrv.Close() + + apiSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != "Bearer test-access-token" { + t.Errorf("expected Authorization: Bearer test-access-token, got %q", auth) + w.WriteHeader(http.StatusUnauthorized) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"ok": true}) + })) + defer apiSrv.Close() + + factory := NewHTTPCallStepFactory() + step, err := factory("oauth-test", map[string]any{ + "url": apiSrv.URL + "/data", + "method": "GET", + "auth": map[string]any{ + "type": "oauth2_client_credentials", + "token_url": tokenSrv.URL + "/token", + "client_id": "my-client", + "client_secret": "my-secret", + }, + }, nil) + if err != nil { + t.Fatalf("factory error: %v", err) + } + // Use a shared transport so both servers are reachable with a single client + step.(*HTTPCallStep).httpClient = &http.Client{} + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + if result.Output["status_code"] != http.StatusOK { + t.Errorf("expected 200, got %v", result.Output["status_code"]) + } + if atomic.LoadInt32(&tokenRequests) != 1 { + t.Errorf("expected 1 token request, got %d", atomic.LoadInt32(&tokenRequests)) + } +} + +// TestHTTPCallStep_OAuth2_TokenCached verifies that a second call reuses the cached token. +func TestHTTPCallStep_OAuth2_TokenCached(t *testing.T) { + var tokenRequests int32 + + tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&tokenRequests, 1) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "cached-token", + "expires_in": 3600, + "token_type": "Bearer", + }) + })) + defer tokenSrv.Close() + + apiSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"ok": true}) + })) + defer apiSrv.Close() + + factory := NewHTTPCallStepFactory() + step, err := factory("cache-test", map[string]any{ + "url": apiSrv.URL, + "method": "GET", + "auth": map[string]any{ + "type": "oauth2_client_credentials", + "token_url": tokenSrv.URL + "/token", + "client_id": "cid", + "client_secret": "csec", + }, + }, nil) + if err != nil { + t.Fatalf("factory error: %v", err) + } + step.(*HTTPCallStep).httpClient = &http.Client{} + + pc := NewPipelineContext(nil, nil) + + // First call – token is fetched + if _, err := step.Execute(context.Background(), pc); err != nil { + t.Fatalf("first execute error: %v", err) + } + // Second call – token is reused from cache + if _, err := step.Execute(context.Background(), pc); err != nil { + t.Fatalf("second execute error: %v", err) + } + + if atomic.LoadInt32(&tokenRequests) != 1 { + t.Errorf("expected token to be fetched only once, got %d requests", atomic.LoadInt32(&tokenRequests)) + } +} + +// TestHTTPCallStep_OAuth2_Retry401 verifies that a 401 triggers token invalidation and retry. +func TestHTTPCallStep_OAuth2_Retry401(t *testing.T) { + var tokenRequests int32 + var apiRequests int32 + + tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&tokenRequests, 1) + token := "token-v1" + if n > 1 { + token = "token-v2" + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": token, + "expires_in": 3600, + "token_type": "Bearer", + }) + })) + defer tokenSrv.Close() + + apiSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&apiRequests, 1) + if n == 1 { + // First call: return 401 + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`unauthorized`)) + return + } + // Retry: verify fresh token + if r.Header.Get("Authorization") != "Bearer token-v2" { + t.Errorf("expected Bearer token-v2, got %q", r.Header.Get("Authorization")) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"ok": true}) + })) + defer apiSrv.Close() + + factory := NewHTTPCallStepFactory() + step, err := factory("retry-test", map[string]any{ + "url": apiSrv.URL + "/api", + "method": "GET", + "auth": map[string]any{ + "type": "oauth2_client_credentials", + "token_url": tokenSrv.URL + "/token", + "client_id": "cid", + "client_secret": "csec", + }, + }, nil) + if err != nil { + t.Fatalf("factory error: %v", err) + } + step.(*HTTPCallStep).httpClient = &http.Client{} + + pc := NewPipelineContext(nil, nil) + result, err := step.Execute(context.Background(), pc) + if err != nil { + t.Fatalf("execute error: %v", err) + } + + if result.Output["status_code"] != http.StatusOK { + t.Errorf("expected 200 after retry, got %v", result.Output["status_code"]) + } + if atomic.LoadInt32(&tokenRequests) != 2 { + t.Errorf("expected 2 token requests, got %d", atomic.LoadInt32(&tokenRequests)) + } + if atomic.LoadInt32(&apiRequests) != 2 { + t.Errorf("expected 2 API requests, got %d", atomic.LoadInt32(&apiRequests)) + } +} + +// TestHTTPCallStep_OAuth2_Scopes verifies that scopes are sent in the token request. +func TestHTTPCallStep_OAuth2_Scopes(t *testing.T) { + tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseForm() + scope := r.FormValue("scope") + if scope != "read write" { + t.Errorf("expected scope='read write', got %q", scope) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "scoped-token", + "expires_in": 3600, + }) + })) + defer tokenSrv.Close() + + apiSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{}`)) + })) + defer apiSrv.Close() + + factory := NewHTTPCallStepFactory() + step, err := factory("scope-test", map[string]any{ + "url": apiSrv.URL, + "method": "GET", + "auth": map[string]any{ + "type": "oauth2_client_credentials", + "token_url": tokenSrv.URL + "/token", + "client_id": "cid", + "client_secret": "csec", + "scopes": []any{"read", "write"}, + }, + }, nil) + if err != nil { + t.Fatalf("factory error: %v", err) + } + step.(*HTTPCallStep).httpClient = &http.Client{} + + pc := NewPipelineContext(nil, nil) + if _, err := step.Execute(context.Background(), pc); err != nil { + t.Fatalf("execute error: %v", err) + } +} + +// TestHTTPCallStep_OAuth2_MissingFields verifies that missing auth fields produce errors. +func TestHTTPCallStep_OAuth2_MissingFields(t *testing.T) { + factory := NewHTTPCallStepFactory() + + tests := []struct { + name string + auth map[string]any + errMsg string + }{ + { + name: "missing token_url", + auth: map[string]any{ + "type": "oauth2_client_credentials", + "client_id": "cid", + "client_secret": "csec", + }, + errMsg: "auth.token_url is required", + }, + { + name: "missing client_id", + auth: map[string]any{ + "type": "oauth2_client_credentials", + "token_url": "http://example.com/token", + "client_secret": "csec", + }, + errMsg: "auth.client_id is required", + }, + { + name: "missing client_secret", + auth: map[string]any{ + "type": "oauth2_client_credentials", + "token_url": "http://example.com/token", + "client_id": "cid", + }, + errMsg: "auth.client_secret is required", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := factory("test", map[string]any{ + "url": "http://example.com/api", + "auth": tc.auth, + }, nil) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), tc.errMsg) { + t.Errorf("expected %q in error, got: %v", tc.errMsg, err) + } + }) + } +} + +// TestHTTPCallStep_OAuth2_TokenExpiry verifies that an expired token is refreshed. +func TestHTTPCallStep_OAuth2_TokenExpiry(t *testing.T) { + var tokenRequests int32 + + tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&tokenRequests, 1) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "short-lived-token", + "expires_in": 1, // 1 second TTL (minus 10s buffer => immediately invalid) + "token_type": "Bearer", + }) + })) + defer tokenSrv.Close() + + apiSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"ok": true}) + })) + defer apiSrv.Close() + + factory := NewHTTPCallStepFactory() + step, err := factory("expiry-test", map[string]any{ + "url": apiSrv.URL, + "method": "GET", + "auth": map[string]any{ + "type": "oauth2_client_credentials", + "token_url": tokenSrv.URL + "/token", + "client_id": "cid", + "client_secret": "csec", + }, + }, nil) + if err != nil { + t.Fatalf("factory error: %v", err) + } + step.(*HTTPCallStep).httpClient = &http.Client{} + + pc := NewPipelineContext(nil, nil) + + // First call fetches a token with TTL=1s; after subtracting 10s buffer it expires immediately. + if _, err := step.Execute(context.Background(), pc); err != nil { + t.Fatalf("first execute error: %v", err) + } + // Force expiry: sleep briefly then call again; the cache should be empty. + time.Sleep(50 * time.Millisecond) + // Manually set the token expiry to the past to simulate expiration. + step.(*HTTPCallStep).tokenCache.mu.Lock() + step.(*HTTPCallStep).tokenCache.expiry = time.Now().Add(-time.Second) + step.(*HTTPCallStep).tokenCache.mu.Unlock() + + if _, err := step.Execute(context.Background(), pc); err != nil { + t.Fatalf("second execute error: %v", err) + } + + if atomic.LoadInt32(&tokenRequests) != 2 { + t.Errorf("expected token to be fetched twice (once per expired cache), got %d", atomic.LoadInt32(&tokenRequests)) + } +} From 06274fe307ead3b10245d6051661e4997dec0a9e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 02:55:22 +0000 Subject: [PATCH 3/3] Fix token cache per-instance, race condition, and multi-tenancy for OAuth2 http_call step Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> --- module/pipeline_step_http_call.go | 132 ++++++++++++++++++------- module/pipeline_step_http_call_test.go | 73 +++++++++++++- 2 files changed, 166 insertions(+), 39 deletions(-) diff --git a/module/pipeline_step_http_call.go b/module/pipeline_step_http_call.go index 429bd4f9..3a05b539 100644 --- a/module/pipeline_step_http_call.go +++ b/module/pipeline_step_http_call.go @@ -12,48 +12,85 @@ import ( "sync" "time" + "golang.org/x/sync/singleflight" + "github.com/CrisisTextLine/modular" ) -// oauthConfig holds OAuth2 client_credentials configuration. -type oauthConfig struct { - tokenURL string - clientID string - clientSecret string - scopes []string +// globalOAuthCache is a process-wide registry of OAuth2 token cache entries, shared across all +// HTTPCallStep instances. Entries are keyed by a credential fingerprint (token URL + client ID + +// client secret + scopes), so each distinct set of credentials (i.e. each tenant) gets its own +// isolated entry. +var globalOAuthCache = &oauthTokenCache{ //nolint:gochecknoglobals // intentional process-wide cache + entries: make(map[string]*oauthCacheEntry), +} + +// oauthTokenCache is a registry of per-credential token cache entries. +type oauthTokenCache struct { + mu sync.RWMutex + entries map[string]*oauthCacheEntry } -// tokenCache holds a cached OAuth2 access token and its expiry. -type tokenCache struct { +// getOrCreate returns the existing cache entry for key, or creates and stores a new one. +func (c *oauthTokenCache) getOrCreate(key string) *oauthCacheEntry { + c.mu.RLock() + entry, ok := c.entries[key] + c.mu.RUnlock() + if ok { + return entry + } + c.mu.Lock() + defer c.mu.Unlock() + if entry, ok = c.entries[key]; ok { + return entry + } + entry = &oauthCacheEntry{} + c.entries[key] = entry + return entry +} + +// oauthCacheEntry holds a cached OAuth2 access token with expiry. A singleflight.Group is +// embedded to ensure at most one concurrent token fetch per credential set. +type oauthCacheEntry struct { mu sync.Mutex accessToken string expiry time.Time + sfGroup singleflight.Group } -// get returns the cached token if it is still valid, or empty string. -func (c *tokenCache) get() string { - c.mu.Lock() - defer c.mu.Unlock() - if c.accessToken != "" && time.Now().Before(c.expiry) { - return c.accessToken +// get returns the cached token if still valid, or an empty string. +func (e *oauthCacheEntry) get() string { + e.mu.Lock() + defer e.mu.Unlock() + if e.accessToken != "" && time.Now().Before(e.expiry) { + return e.accessToken } return "" } // set stores a token with the given TTL. -func (c *tokenCache) set(token string, ttl time.Duration) { - c.mu.Lock() - defer c.mu.Unlock() - c.accessToken = token - c.expiry = time.Now().Add(ttl) +func (e *oauthCacheEntry) set(token string, ttl time.Duration) { + e.mu.Lock() + defer e.mu.Unlock() + e.accessToken = token + e.expiry = time.Now().Add(ttl) } // invalidate clears the cached token. -func (c *tokenCache) invalidate() { - c.mu.Lock() - defer c.mu.Unlock() - c.accessToken = "" - c.expiry = time.Time{} +func (e *oauthCacheEntry) invalidate() { + e.mu.Lock() + defer e.mu.Unlock() + e.accessToken = "" + e.expiry = time.Time{} +} + +// oauthConfig holds OAuth2 client_credentials configuration. +type oauthConfig struct { + tokenURL string + clientID string + clientSecret string + scopes []string + cacheKey string // derived from credentials; used for per-tenant cache isolation } // HTTPCallStep makes an HTTP request as a pipeline step. @@ -66,8 +103,8 @@ type HTTPCallStep struct { timeout time.Duration tmpl *TemplateEngine auth *oauthConfig - tokenCache *tokenCache - httpClient *http.Client + oauthEntry *oauthCacheEntry // shared entry from globalOAuthCache; nil when no auth configured + httpClient *http.Client // timeout is enforced via the context passed to each request } // NewHTTPCallStepFactory returns a StepFactory that creates HTTPCallStep instances. @@ -141,13 +178,17 @@ func NewHTTPCallStepFactory() StepFactory { } } + // Cache key incorporates all credential fields so each distinct tenant/client + // gets its own isolated token cache entry. + cacheKey := tokenURL + "\x00" + clientID + "\x00" + clientSecret + "\x00" + strings.Join(scopes, " ") step.auth = &oauthConfig{ tokenURL: tokenURL, clientID: clientID, clientSecret: clientSecret, scopes: scopes, + cacheKey: cacheKey, } - step.tokenCache = &tokenCache{} + step.oauthEntry = globalOAuthCache.getOrCreate(cacheKey) } } @@ -158,8 +199,10 @@ func NewHTTPCallStepFactory() StepFactory { // Name returns the step name. func (s *HTTPCallStep) Name() string { return s.name } -// fetchToken obtains a new OAuth2 access token using client_credentials grant. -func (s *HTTPCallStep) fetchToken(ctx context.Context) (string, error) { +// doFetchToken performs the actual HTTP call to the token endpoint, caches the result, and returns +// the new access token. It is called either via getToken (through singleflight) or directly on +// the 401-retry path where an unconditional refresh is needed. +func (s *HTTPCallStep) doFetchToken(ctx context.Context) (string, error) { params := url.Values{ "grant_type": {"client_credentials"}, "client_id": {s.auth.clientID}, @@ -211,17 +254,33 @@ func (s *HTTPCallStep) fetchToken(ctx context.Context) (string, error) { if ttl > 10*time.Second { ttl -= 10 * time.Second } - s.tokenCache.set(tokenResp.AccessToken, ttl) + s.oauthEntry.set(tokenResp.AccessToken, ttl) return tokenResp.AccessToken, nil } -// getToken returns a valid OAuth2 token, fetching one if the cache is empty or expired. +// getToken returns a valid OAuth2 token from the shared cache. If the cache is empty or expired, +// a single network fetch is performed; concurrent callers for the same credential set are +// coalesced via singleflight so the token endpoint is called at most once. func (s *HTTPCallStep) getToken(ctx context.Context) (string, error) { - if token := s.tokenCache.get(); token != "" { + // Fast path: valid token already in the shared cache. + if token := s.oauthEntry.get(); token != "" { return token, nil } - return s.fetchToken(ctx) + + // Slow path: coalesce concurrent fetches so only one goroutine calls the token endpoint. + val, err, _ := s.oauthEntry.sfGroup.Do("fetch", func() (any, error) { + // Double-check inside the group so we don't fetch again if a concurrent goroutine + // already populated the cache while we were waiting. + if token := s.oauthEntry.get(); token != "" { + return token, nil + } + return s.doFetchToken(ctx) + }) + if err != nil { + return "", err + } + return val.(string), nil } // buildBodyReader constructs the request body reader from the step configuration. @@ -344,11 +403,12 @@ func (s *HTTPCallStep) Execute(ctx context.Context, pc *PipelineContext) (*StepR return nil, fmt.Errorf("http_call step %q: failed to read response: %w", s.name, err) } - // On 401, invalidate token cache and retry once with a fresh token + // On 401, invalidate the shared cache and fetch a fresh token directly (bypassing + // singleflight so the refresh is not coalesced with an in-progress normal fetch). if resp.StatusCode == http.StatusUnauthorized && s.auth != nil { - s.tokenCache.invalidate() + s.oauthEntry.invalidate() - newToken, tokenErr := s.fetchToken(ctx) + newToken, tokenErr := s.doFetchToken(ctx) if tokenErr != nil { return nil, tokenErr } diff --git a/module/pipeline_step_http_call_test.go b/module/pipeline_step_http_call_test.go index 651efb92..e5c4a29f 100644 --- a/module/pipeline_step_http_call_test.go +++ b/module/pipeline_step_http_call_test.go @@ -3,6 +3,7 @@ package module import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "strings" @@ -422,9 +423,9 @@ func TestHTTPCallStep_OAuth2_TokenExpiry(t *testing.T) { // Force expiry: sleep briefly then call again; the cache should be empty. time.Sleep(50 * time.Millisecond) // Manually set the token expiry to the past to simulate expiration. - step.(*HTTPCallStep).tokenCache.mu.Lock() - step.(*HTTPCallStep).tokenCache.expiry = time.Now().Add(-time.Second) - step.(*HTTPCallStep).tokenCache.mu.Unlock() + step.(*HTTPCallStep).oauthEntry.mu.Lock() + step.(*HTTPCallStep).oauthEntry.expiry = time.Now().Add(-time.Second) + step.(*HTTPCallStep).oauthEntry.mu.Unlock() if _, err := step.Execute(context.Background(), pc); err != nil { t.Fatalf("second execute error: %v", err) @@ -434,3 +435,69 @@ func TestHTTPCallStep_OAuth2_TokenExpiry(t *testing.T) { t.Errorf("expected token to be fetched twice (once per expired cache), got %d", atomic.LoadInt32(&tokenRequests)) } } + +// TestHTTPCallStep_OAuth2_ConcurrentFetch verifies that concurrent executions on different step +// instances sharing the same credentials only call the token endpoint once (singleflight). +func TestHTTPCallStep_OAuth2_ConcurrentFetch(t *testing.T) { + var tokenRequests int32 + + tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Small delay to allow multiple goroutines to pile up before the first response. + time.Sleep(20 * time.Millisecond) + atomic.AddInt32(&tokenRequests, 1) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "shared-token", + "expires_in": 3600, + }) + })) + defer tokenSrv.Close() + + apiSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"ok": true}) + })) + defer apiSrv.Close() + + // Use a unique client_secret per test run to get a fresh global cache entry. + uniqueSecret := fmt.Sprintf("concurrent-secret-%d", time.Now().UnixNano()) + + factory := NewHTTPCallStepFactory() + authCfg := map[string]any{ + "type": "oauth2_client_credentials", + "token_url": tokenSrv.URL + "/token", + "client_id": "concurrent-cid", + "client_secret": uniqueSecret, + } + + const concurrency = 5 + errs := make(chan error, concurrency) + + for i := 0; i < concurrency; i++ { + go func() { + step, err := factory("concurrent-test", map[string]any{ + "url": apiSrv.URL, + "method": "GET", + "auth": authCfg, + }, nil) + if err != nil { + errs <- err + return + } + step.(*HTTPCallStep).httpClient = &http.Client{} + pc := NewPipelineContext(nil, nil) + _, err = step.Execute(context.Background(), pc) + errs <- err + }() + } + + for i := 0; i < concurrency; i++ { + if err := <-errs; err != nil { + t.Errorf("goroutine error: %v", err) + } + } + + if n := atomic.LoadInt32(&tokenRequests); n != 1 { + t.Errorf("expected exactly 1 token request via singleflight, got %d", n) + } +}