diff --git a/pool/result_pool.go b/pool/result_pool.go index f73a772..21cdab1 100644 --- a/pool/result_pool.go +++ b/pool/result_pool.go @@ -129,14 +129,11 @@ func (r *resultAggregator[T]) collect(collectErrored bool) []T { return r.results } - filtered := r.results[:0] + filtered := r.results sort.Ints(r.errored) - for i, e := range r.errored { - if i == 0 { - filtered = append(filtered, r.results[:e]...) - } else { - filtered = append(filtered, r.results[r.errored[i-1]+1:e]...) - } + for i := range r.errored { + e := r.errored[len(r.errored)-1-i] + filtered = append(filtered[:e], filtered[e+1:]...) } return filtered } diff --git a/pool/result_pool_test.go b/pool/result_pool_test.go index 69b9de4..209d4f5 100644 --- a/pool/result_pool_test.go +++ b/pool/result_pool_test.go @@ -1,6 +1,7 @@ package pool_test import ( + "errors" "fmt" "math/rand" "strconv" @@ -83,6 +84,25 @@ func TestResultGroup(t *testing.T) { require.Equal(t, results, got) }) + t.Run("all results collected", func(t *testing.T) { + t.Parallel() + p := pool.NewWithResults[int]().WithErrors() + want := []int{1, 3} + + p.Go(func() (int, error) { + return 1, nil + }) + p.Go(func() (int, error) { + return 2, errors.New("an error") + }) + p.Go(func() (int, error) { + return 3, nil + }) + + got, _ := p.Wait() + require.Equal(t, want, got) + }) + t.Run("limit", func(t *testing.T) { t.Parallel() for _, maxGoroutines := range []int{1, 10, 100} {