Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 122 additions & 50 deletions module/pipeline_step_http_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ func (c *oauthTokenCache) getOrCreate(key string) *oauthCacheEntry {
type oauthCacheEntry struct {
mu sync.Mutex
accessToken string
instanceURL string // optional; populated when the token endpoint returns instance_url (Salesforce pattern)
expiry time.Time
sfGroup singleflight.Group
}
Expand All @@ -68,19 +69,28 @@ func (e *oauthCacheEntry) get() string {
return ""
}

// set stores a token with the given TTL.
func (e *oauthCacheEntry) set(token string, ttl time.Duration) {
// getInstanceURL returns the cached instance_url (may be empty if the token endpoint did not return one).
func (e *oauthCacheEntry) getInstanceURL() string {
e.mu.Lock()
defer e.mu.Unlock()
return e.instanceURL
}

// set stores a token and optional instance_url with the given TTL.
func (e *oauthCacheEntry) set(token, instanceURL string, ttl time.Duration) {
e.mu.Lock()
defer e.mu.Unlock()
e.accessToken = token
e.instanceURL = instanceURL
e.expiry = time.Now().Add(ttl)
}

// invalidate clears the cached token.
// invalidate clears the cached token and instance_url.
func (e *oauthCacheEntry) invalidate() {
e.mu.Lock()
defer e.mu.Unlock()
e.accessToken = ""
e.instanceURL = ""
e.expiry = time.Time{}
}

Expand Down Expand Up @@ -156,51 +166,86 @@ func NewHTTPCallStepFactory() StepFactory {
if authCfg, ok := config["auth"].(map[string]any); ok {
authType, _ := authCfg["type"].(string)
if authType == "oauth2_client_credentials" {
tokenURL, _ := authCfg["token_url"].(string)
if tokenURL == "" {
return nil, fmt.Errorf("http_call step %q: auth.token_url is required for oauth2_client_credentials", name)
}
clientID, _ := authCfg["client_id"].(string)
if clientID == "" {
return nil, fmt.Errorf("http_call step %q: auth.client_id is required for oauth2_client_credentials", name)
}
clientSecret, _ := authCfg["client_secret"].(string)
if clientSecret == "" {
return nil, fmt.Errorf("http_call step %q: auth.client_secret is required for oauth2_client_credentials", name)
}

var scopes []string
if raw, ok := authCfg["scopes"]; ok {
switch v := raw.(type) {
case []string:
scopes = v
case []any:
for _, s := range v {
if str, ok := s.(string); ok {
scopes = append(scopes, str)
}
}
}
cfg, oauthErr := buildOAuthConfig(name, "auth", authCfg)
if oauthErr != nil {
return nil, oauthErr
}
step.auth = cfg
step.oauthEntry = globalOAuthCache.getOrCreate(cfg.cacheKey)
}
}

// Cache key incorporates all credential fields so each distinct tenant/client
// gets its own isolated token cache entry.
cacheKey := tokenURL + "\x00" + clientID + "\x00" + clientSecret + "\x00" + strings.Join(scopes, " ")
step.auth = &oauthConfig{
tokenURL: tokenURL,
clientID: clientID,
clientSecret: clientSecret,
scopes: scopes,
cacheKey: cacheKey,
}
step.oauthEntry = globalOAuthCache.getOrCreate(cacheKey)
// Support top-level "oauth2" key as an alternative to "auth" with type=oauth2_client_credentials.
// This follows the syntax proposed in the issue and is more idiomatic for Salesforce-style configs:
// oauth2:
// grant_type: client_credentials (optional, defaults to client_credentials)
// token_url: "..."
// client_id: "..."
// client_secret: "..."
// scopes: ["api"]
// Note: if the "auth" block is also present, it takes precedence and "oauth2" is ignored.
if oauth2Cfg, ok := config["oauth2"].(map[string]any); ok && step.auth == nil {
grantType, _ := oauth2Cfg["grant_type"].(string)
if grantType == "" {
grantType = "client_credentials"
}
if grantType != "client_credentials" {
return nil, fmt.Errorf("http_call step %q: oauth2.grant_type must be 'client_credentials'", name)
}
cfg, oauthErr := buildOAuthConfig(name, "oauth2", oauth2Cfg)
if oauthErr != nil {
return nil, oauthErr
}
step.auth = cfg
step.oauthEntry = globalOAuthCache.getOrCreate(cfg.cacheKey)
}

return step, nil
}
}

// buildOAuthConfig parses OAuth2 client_credentials fields from a config map and returns an
// oauthConfig. The prefix parameter ("auth" or "oauth2") is used in error messages.
func buildOAuthConfig(stepName, prefix string, cfg map[string]any) (*oauthConfig, error) {
tokenURL, _ := cfg["token_url"].(string)
if tokenURL == "" {
return nil, fmt.Errorf("http_call step %q: %s.token_url is required", stepName, prefix)
}
clientID, _ := cfg["client_id"].(string)
if clientID == "" {
return nil, fmt.Errorf("http_call step %q: %s.client_id is required", stepName, prefix)
}
clientSecret, _ := cfg["client_secret"].(string)
if clientSecret == "" {
return nil, fmt.Errorf("http_call step %q: %s.client_secret is required", stepName, prefix)
}

var scopes []string
if raw, ok := cfg["scopes"]; ok {
switch v := raw.(type) {
case []string:
scopes = v
case []any:
for _, s := range v {
if str, ok := s.(string); ok {
scopes = append(scopes, str)
}
}
}
}

// Cache key incorporates all credential fields so each distinct tenant/client
// gets its own isolated token cache entry.
cacheKey := tokenURL + "\x00" + clientID + "\x00" + clientSecret + "\x00" + strings.Join(scopes, " ")
return &oauthConfig{
tokenURL: tokenURL,
clientID: clientID,
clientSecret: clientSecret,
scopes: scopes,
cacheKey: cacheKey,
}, nil
}

// Name returns the step name.
func (s *HTTPCallStep) Name() string { return s.name }

Expand Down Expand Up @@ -243,6 +288,7 @@ func (s *HTTPCallStep) doFetchToken(ctx context.Context) (string, error) {
AccessToken string `json:"access_token"` //nolint:gosec // G117: parsing OAuth2 token response, not a secret exposure
ExpiresIn float64 `json:"expires_in"`
TokenType string `json:"token_type"`
InstanceURL string `json:"instance_url"` // Salesforce pattern: base URL for subsequent API calls
}
if err := json.Unmarshal(body, &tokenResp); err != nil {
return "", fmt.Errorf("http_call step %q: failed to parse token response: %w", s.name, err)
Expand All @@ -259,7 +305,7 @@ func (s *HTTPCallStep) doFetchToken(ctx context.Context) (string, error) {
if ttl > 10*time.Second {
ttl -= 10 * time.Second
}
s.oauthEntry.set(tokenResp.AccessToken, ttl)
s.oauthEntry.set(tokenResp.AccessToken, tokenResp.InstanceURL, ttl)

return tokenResp.AccessToken, nil
}
Expand Down Expand Up @@ -392,6 +438,22 @@ func (s *HTTPCallStep) Execute(ctx context.Context, pc *PipelineContext) (*StepR
ctx, cancel := context.WithTimeout(ctx, s.timeout)
defer cancel()

// Obtain OAuth2 bearer token first so that instance_url is available for URL template resolution.
var bearerToken string
var err error
if s.auth != nil {
bearerToken, err = s.getToken(ctx)
if err != nil {
return nil, err
}
// Inject instance_url into the pipeline context so URL/header templates can reference it
// as {{ .instance_url }}. This is a Salesforce pattern where the token endpoint returns the
// org-specific base URL alongside the access token.
if instanceURL := s.oauthEntry.getInstanceURL(); instanceURL != "" {
pc.Current["instance_url"] = instanceURL
}
}

// Resolve URL template
resolvedURL, err := s.tmpl.Resolve(s.url, pc)
if err != nil {
Expand All @@ -403,15 +465,6 @@ func (s *HTTPCallStep) Execute(ctx context.Context, pc *PipelineContext) (*StepR
return nil, err
}

// Obtain OAuth2 bearer token if auth is configured
var bearerToken string
if s.auth != nil {
bearerToken, err = s.getToken(ctx)
if err != nil {
return nil, err
}
}

req, err := s.buildRequest(ctx, resolvedURL, bodyReader, rawBody, pc, bearerToken)
if err != nil {
return nil, err
Expand All @@ -438,11 +491,22 @@ func (s *HTTPCallStep) Execute(ctx context.Context, pc *PipelineContext) (*StepR
return nil, tokenErr
}

// After a token refresh, instance_url may have changed (Salesforce can rotate it).
// Re-inject it into pc.Current and re-resolve the URL template so the retry
// hits the correct host.
if instanceURL := s.oauthEntry.getInstanceURL(); instanceURL != "" {
pc.Current["instance_url"] = instanceURL
}
retryURL, resolveErr := s.tmpl.Resolve(s.url, pc)
if resolveErr != nil {
return nil, fmt.Errorf("http_call step %q: failed to resolve url for retry: %w", s.name, resolveErr)
}

retryBody, rawBody2, buildErr := s.buildBodyReader(pc)
if buildErr != nil {
return nil, buildErr
}
retryReq, buildErr := s.buildRequest(ctx, resolvedURL, retryBody, rawBody2, pc, newToken)
retryReq, buildErr := s.buildRequest(ctx, retryURL, retryBody, rawBody2, pc, newToken)
if buildErr != nil {
return nil, buildErr
}
Expand All @@ -459,13 +523,21 @@ func (s *HTTPCallStep) Execute(ctx context.Context, pc *PipelineContext) (*StepR
}

output := parseHTTPResponse(retryResp, respBody)
if instanceURL := s.oauthEntry.getInstanceURL(); instanceURL != "" {
output["instance_url"] = instanceURL
}
Comment on lines 525 to +528
Copy link

Copilot AI Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 401-retry path now returns instance_url from the refreshed token, but the retry request is still sent to the pre-refresh resolvedURL and uses whatever pc.Current contained before refresh. If instance_url changes between token responses (Salesforce can), the retry may hit the wrong host. After doFetchToken on 401, refresh pc.Current["instance_url"] from the cache and re-resolve the URL/template-dependent fields before building the retry request.

Copilot uses AI. Check for mistakes.
if retryResp.StatusCode >= 400 {
return nil, fmt.Errorf("http_call step %q: HTTP %d: %s", s.name, retryResp.StatusCode, string(respBody))
}
return &StepResult{Output: output}, nil
}

output := parseHTTPResponse(resp, respBody)
if s.auth != nil {
if instanceURL := s.oauthEntry.getInstanceURL(); instanceURL != "" {
output["instance_url"] = instanceURL
}
}

if resp.StatusCode >= 400 {
return nil, fmt.Errorf("http_call step %q: HTTP %d: %s", s.name, resp.StatusCode, string(respBody))
Expand Down
Loading
Loading