diff --git a/sentinelone/sdk.go b/sentinelone/sdk.go index 6ee4ea1..e58b725 100644 --- a/sentinelone/sdk.go +++ b/sentinelone/sdk.go @@ -36,37 +36,105 @@ type SentinelOnePagedData struct { // GetFromAPI retrieves data from the API based on the provided options func (c *SentinelOneClient) GetFromAPI(ctx context.Context, endpoint string, opts url.Values) (*SentinelOnePagedData, error) { - req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s%s", c.baseURL, endpoint), nil) - if err != nil { - return nil, fmt.Errorf("failed to create request %q: %v", req.URL.String(), err) - } + const maxRetryDuration = 1 * time.Minute + const initialDelay = 1 * time.Second + const maxDelay = 10 * time.Second - // Add query parameters - req.URL.RawQuery = opts.Encode() + startTime := time.Now() + delay := initialDelay + reqURL := fmt.Sprintf("%s%s", c.baseURL, endpoint) - // Add authentication - req.Header.Set("Authorization", "Bearer "+c.apiToken) - req.Header.Set("Content-Type", "application/json") + for { + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request %q: %v", reqURL, err) + } - resp, err := c.httpSentinelOneClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to execute request %q: %v", req.URL.String(), err) - } - defer resp.Body.Close() + // Add query parameters + req.URL.RawQuery = opts.Encode() + + // Add authentication + req.Header.Set("Authorization", "Bearer "+c.apiToken) + req.Header.Set("Content-Type", "application/json") - if resp.StatusCode != http.StatusOK { - // Read the response body for an error message - body, err := io.ReadAll(resp.Body) + resp, err := c.httpSentinelOneClient.Do(req) if err != nil { - return nil, fmt.Errorf("failed to read response body %q: %v", req.URL.String(), err) + // Check if we should retry on timeout or connection errors + if time.Since(startTime) < maxRetryDuration && isRetriableError(err) { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("context cancelled during retry: %v", err) + case <-time.After(delay): + delay = delay * 2 + if delay > maxDelay { + delay = maxDelay + } + continue + } + } + return nil, fmt.Errorf("failed to execute request %q: %v", req.URL.String(), err) + } + + // Read the body once and handle the response + body, readErr := io.ReadAll(resp.Body) + resp.Body.Close() + + if readErr != nil { + if time.Since(startTime) < maxRetryDuration { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("context cancelled during retry: %v", readErr) + case <-time.After(delay): + delay = delay * 2 + if delay > maxDelay { + delay = maxDelay + } + continue + } + } + return nil, fmt.Errorf("failed to read response body %q: %v", req.URL.String(), readErr) } - return nil, fmt.Errorf("unexpected status code %d for %q: %s", resp.StatusCode, req.URL.String(), string(body)) - } - var result SentinelOnePagedData - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, fmt.Errorf("failed to decode response %q: %v", req.URL.String(), err) + // Check for retryable status codes (503, 504) + if resp.StatusCode == http.StatusServiceUnavailable || resp.StatusCode == http.StatusGatewayTimeout { + if time.Since(startTime) < maxRetryDuration { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("context cancelled during retry for status %d: %s", resp.StatusCode, string(body)) + case <-time.After(delay): + delay = delay * 2 + if delay > maxDelay { + delay = maxDelay + } + continue + } + } + // Retries exhausted + return nil, fmt.Errorf("unexpected status code %d for %q after retries: %s", resp.StatusCode, req.URL.String(), string(body)) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code %d for %q: %s", resp.StatusCode, req.URL.String(), string(body)) + } + + var result SentinelOnePagedData + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to decode response %q: %v", req.URL.String(), err) + } + + return &result, nil } +} - return &result, nil +// isRetriableError checks if an error is retriable (timeout or temporary network error) +func isRetriableError(err error) bool { + if err == nil { + return false + } + errStr := err.Error() + return strings.Contains(errStr, "context deadline exceeded") || + strings.Contains(errStr, "Client.Timeout") || + strings.Contains(errStr, "timeout") || + strings.Contains(errStr, "connection refused") || + strings.Contains(errStr, "connection reset") }