Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 92 additions & 24 deletions sentinelone/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,37 +36,105 @@ type SentinelOnePagedData struct {

// GetFromAPI retrieves data from the API based on the provided options
func (c *SentinelOneClient) GetFromAPI(ctx context.Context, endpoint string, opts url.Values) (*SentinelOnePagedData, error) {
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s%s", c.baseURL, endpoint), nil)
if err != nil {
return nil, fmt.Errorf("failed to create request %q: %v", req.URL.String(), err)
}
const maxRetryDuration = 1 * time.Minute
const initialDelay = 1 * time.Second
const maxDelay = 10 * time.Second

// Add query parameters
req.URL.RawQuery = opts.Encode()
startTime := time.Now()
delay := initialDelay
reqURL := fmt.Sprintf("%s%s", c.baseURL, endpoint)

// Add authentication
req.Header.Set("Authorization", "Bearer "+c.apiToken)
req.Header.Set("Content-Type", "application/json")
for {
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request %q: %v", reqURL, err)
}

resp, err := c.httpSentinelOneClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute request %q: %v", req.URL.String(), err)
}
defer resp.Body.Close()
// Add query parameters
req.URL.RawQuery = opts.Encode()

// Add authentication
req.Header.Set("Authorization", "Bearer "+c.apiToken)
req.Header.Set("Content-Type", "application/json")

if resp.StatusCode != http.StatusOK {
// Read the response body for an error message
body, err := io.ReadAll(resp.Body)
resp, err := c.httpSentinelOneClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to read response body %q: %v", req.URL.String(), err)
// Check if we should retry on timeout or connection errors
if time.Since(startTime) < maxRetryDuration && isRetriableError(err) {
select {
case <-ctx.Done():
return nil, fmt.Errorf("context cancelled during retry: %v", err)
case <-time.After(delay):
delay = delay * 2
if delay > maxDelay {
delay = maxDelay
}
continue
}
}
return nil, fmt.Errorf("failed to execute request %q: %v", req.URL.String(), err)
}

// Read the body once and handle the response
body, readErr := io.ReadAll(resp.Body)
resp.Body.Close()

if readErr != nil {
if time.Since(startTime) < maxRetryDuration {
select {
case <-ctx.Done():
return nil, fmt.Errorf("context cancelled during retry: %v", readErr)
case <-time.After(delay):
delay = delay * 2
if delay > maxDelay {
delay = maxDelay
}
continue
}
}
return nil, fmt.Errorf("failed to read response body %q: %v", req.URL.String(), readErr)
}
return nil, fmt.Errorf("unexpected status code %d for %q: %s", resp.StatusCode, req.URL.String(), string(body))
}

var result SentinelOnePagedData
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode response %q: %v", req.URL.String(), err)
// Check for retryable status codes (503, 504)
if resp.StatusCode == http.StatusServiceUnavailable || resp.StatusCode == http.StatusGatewayTimeout {
if time.Since(startTime) < maxRetryDuration {
select {
case <-ctx.Done():
return nil, fmt.Errorf("context cancelled during retry for status %d: %s", resp.StatusCode, string(body))
case <-time.After(delay):
delay = delay * 2
if delay > maxDelay {
delay = maxDelay
}
continue
}
}
// Retries exhausted
return nil, fmt.Errorf("unexpected status code %d for %q after retries: %s", resp.StatusCode, req.URL.String(), string(body))
}

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code %d for %q: %s", resp.StatusCode, req.URL.String(), string(body))
}

var result SentinelOnePagedData
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to decode response %q: %v", req.URL.String(), err)
}

return &result, nil
}
}

return &result, nil
// isRetriableError checks if an error is retriable (timeout or temporary network error)
func isRetriableError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
return strings.Contains(errStr, "context deadline exceeded") ||
strings.Contains(errStr, "Client.Timeout") ||
strings.Contains(errStr, "timeout") ||
strings.Contains(errStr, "connection refused") ||
strings.Contains(errStr, "connection reset")
}