diff --git a/cache/disk/disk.go b/cache/disk/disk.go index d250565..c601dc4 100644 --- a/cache/disk/disk.go +++ b/cache/disk/disk.go @@ -76,7 +76,9 @@ type diskCache struct { maxBlobSize int64 maxProxyBlobSize int64 accessLogger *log.Logger - containsQueue chan proxyCheck + + // Limit the number of simultaneous proxy Contains checks. + containsSem *semaphore.Weighted // Limit the number of simultaneous file removals. fileRemovalSem *semaphore.Weighted diff --git a/cache/disk/findmissing.go b/cache/disk/findmissing.go index d1db459..2c76c4b 100644 --- a/cache/disk/findmissing.go +++ b/cache/disk/findmissing.go @@ -9,10 +9,15 @@ import ( pb "github.com/buchgr/bazel-remote/v2/genproto/build/bazel/remote/execution/v2" + "golang.org/x/sync/semaphore" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) +// maxConcurrentContainsChecks bounds the number of simultaneous proxy +// Contains lookups across all FindMissingBlobs requests. +const maxConcurrentContainsChecks = 512 + type proxyCheck struct { wg *sync.WaitGroup digest **pb.Digest @@ -109,26 +114,27 @@ func (c *diskCache) findMissingCasBlobsInternal(ctx context.Context, blobs []*pb continue } - // Adding to the containsQueue channel may have blocked on a previous iteration, - // so check to see if the context has cancelled. - select { - case <-ctx.Done(): + // Acquire blocks while maxConcurrentContainsChecks checks are + // in flight, and returns early if the context is cancelled. + if err := c.containsSem.Acquire(ctx, 1); err != nil { if cancelledDueToFailFast { return errMissingBlob } return errRequestCancelled - default: } wg.Add(1) - c.containsQueue <- proxyCheck{ + go func(req proxyCheck) { + defer c.containsSem.Release(1) + c.processContainsCheck(req) + }(proxyCheck{ wg: &wg, digest: &chunk[i], ctx: ctx, // When failFast is true, onProxyMiss will have been set to a function that // will cancel the context, causing the remaining proxyChecks to short-circuit. onProxyMiss: cancelContextForFailFast, - } + }) } } } @@ -206,43 +212,37 @@ func (c *diskCache) findMissingLocalCAS(blobs []*pb.Digest) int { return missing } -func (c *diskCache) containsWorker() { - var ok bool - for req := range c.containsQueue { - if req.ctx != nil { - select { - case <-req.ctx.Done(): - // Fast-fail if the context has already been cancelled. - c.accessLogger.Printf("GRPC CAS HEAD %s CANCELLED", (*req.digest).Hash) - req.wg.Done() - continue - default: - } +// processContainsCheck performs a single proxy Contains lookup and calls +// req.wg.Done exactly once. +func (c *diskCache) processContainsCheck(req proxyCheck) { + defer req.wg.Done() + + if req.ctx != nil { + select { + case <-req.ctx.Done(): + // Fast-fail if the context has already been cancelled. + c.accessLogger.Printf("GRPC CAS HEAD %s CANCELLED", (*req.digest).Hash) + return + default: } + } - ok, _ = c.proxy.Contains(req.ctx, cache.CAS, (*req.digest).Hash, (*req.digest).SizeBytes) - if ok { - c.accessLogger.Printf("GRPC CAS HEAD %s OK", (*req.digest).Hash) - // The blob exists on the proxy, remove it from the - // list of missing blobs. - *(req.digest) = nil - } else { - c.accessLogger.Printf("GRPC CAS HEAD %s NOT FOUND", (*req.digest).Hash) - if req.onProxyMiss != nil { - req.onProxyMiss() - } + ok, _ := c.proxy.Contains(req.ctx, cache.CAS, (*req.digest).Hash, (*req.digest).SizeBytes) + if ok { + c.accessLogger.Printf("GRPC CAS HEAD %s OK", (*req.digest).Hash) + // The blob exists on the proxy, remove it from the + // list of missing blobs. + *(req.digest) = nil + } else { + c.accessLogger.Printf("GRPC CAS HEAD %s NOT FOUND", (*req.digest).Hash) + if req.onProxyMiss != nil { + req.onProxyMiss() } - req.wg.Done() } } -func (c *diskCache) spawnContainsQueueWorkers() { - // TODO: make these configurable? - const queueSize = 2048 - const numWorkers = 512 - - c.containsQueue = make(chan proxyCheck, queueSize) - for i := 0; i < numWorkers; i++ { - go c.containsWorker() +func (c *diskCache) initContainsCheckLimiter() { + if c.containsSem == nil { + c.containsSem = semaphore.NewWeighted(maxConcurrentContainsChecks) } } diff --git a/cache/disk/findmissing_test.go b/cache/disk/findmissing_test.go index 3e24f5e..0bebe6b 100644 --- a/cache/disk/findmissing_test.go +++ b/cache/disk/findmissing_test.go @@ -8,10 +8,13 @@ import ( "io" "os" "sync" + "sync/atomic" "testing" + "time" "github.com/buchgr/bazel-remote/v2/cache" testutils "github.com/buchgr/bazel-remote/v2/utils" + "golang.org/x/sync/semaphore" "google.golang.org/protobuf/proto" pb "github.com/buchgr/bazel-remote/v2/genproto/build/bazel/remote/execution/v2" @@ -104,49 +107,38 @@ func (p *testCWProxy) Contains(ctx context.Context, kind cache.EntryKind, hash s return false, -1 } -func TestContainsWorker(t *testing.T) { +func TestProcessContainsCheck(t *testing.T) { t.Parallel() tp := testCWProxy{blob: "9205adc12a2c8b65e7cd77918ff8e6e20f39bdd0b7fc4b984abfd690c79d80c1"} c := diskCache{ - accessLogger: testutils.NewSilentLogger(), - proxy: &tp, - containsQueue: make(chan proxyCheck, 2), + accessLogger: testutils.NewSilentLogger(), + proxy: &tp, } - // Spawn a single worker. - go c.containsWorker() - digests := []*pb.Digest{ // Expect this to be found in the proxy, and replaced with nil. {Hash: tp.blob, SizeBytes: 42}, // Expect this not to be found in the proxy, and left unchanged. {Hash: "423789fae66b9539c5622134c580700a154a15e355af4e3311a4e12ee0c9d243", SizeBytes: 43}, - } - if cap(c.containsQueue) != len(digests) { - t.Fatalf("Broken test setup, expected containsQueue capacity %d to match number of digests %d", - cap(c.containsQueue), len(digests)) + // Expect this to be left unchanged: its context is already + // cancelled, so the proxy must not be consulted at all. + {Hash: tp.blob, SizeBytes: 42}, } - var wg sync.WaitGroup + cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() - for i := range digests { - wg.Add(1) - c.containsQueue <- proxyCheck{ - wg: &wg, - digest: &digests[i], - } - } - - // Wait for the worker to process each request. + var wg sync.WaitGroup + wg.Add(3) + c.processContainsCheck(proxyCheck{wg: &wg, digest: &digests[0]}) + c.processContainsCheck(proxyCheck{wg: &wg, digest: &digests[1]}) + c.processContainsCheck(proxyCheck{wg: &wg, digest: &digests[2], ctx: cancelledCtx}) wg.Wait() - // Allow the worker goroutine to finish. - close(c.containsQueue) - if digests[0] != nil { t.Error("Expected digests[0] to be found in the proxy and replaced by nil") } @@ -154,6 +146,10 @@ func TestContainsWorker(t *testing.T) { if digests[1] == nil { t.Error("Expected digests[1] to not be found in the proxy and left as-is") } + + if digests[2] == nil { + t.Error("Expected digests[2] to be skipped due to cancelled context and left as-is") + } } type proxyAdapter struct { @@ -261,19 +257,15 @@ func TestFindMissingCasBlobsWithProxyFailFast(t *testing.T) { t.Fatal(err) } - // Explicitly avoid using WithProxyBackEnd, as we want to control the workers. + // Explicitly avoid using WithProxyBackEnd, as we want to control the concurrency limit. testCacheI, err := New(cacheDir, 10*1024, WithAccessLogger(testutils.NewSilentLogger())) if err != nil { t.Fatal(err) } actualDiskCache := testCacheI.(*diskCache) actualDiskCache.proxy = proxy - actualDiskCache.containsQueue = make(chan proxyCheck, 4) - defer func() { - close(actualDiskCache.containsQueue) - }() - // Spawn a single worker. - go actualDiskCache.containsWorker() + // Allow only a single Contains check in flight at a time. + actualDiskCache.containsSem = semaphore.NewWeighted(1) data1, digest1 := testutils.RandomDataAndDigest(100) _, digest2 := testutils.RandomDataAndDigest(200) @@ -320,19 +312,15 @@ func TestFindMissingCasBlobsWithProxyFailFastNoneMissing(t *testing.T) { t.Fatal(err) } - // Explicitly avoid using WithProxyBackEnd, as we want to control the workers. + // Explicitly avoid using WithProxyBackEnd, as we want to control the concurrency limit. testCacheI, err := New(cacheDir, 40*1024, WithAccessLogger(testutils.NewSilentLogger())) if err != nil { t.Fatal(err) } actualDiskCache := testCacheI.(*diskCache) actualDiskCache.proxy = proxy - actualDiskCache.containsQueue = make(chan proxyCheck, 4) - defer func() { - close(actualDiskCache.containsQueue) - }() - // Spawn a single worker. - go actualDiskCache.containsWorker() + // Allow only a single Contains check in flight at a time. + actualDiskCache.containsSem = semaphore.NewWeighted(1) data1, digest1 := testutils.RandomDataAndDigest(100) data2, digest2 := testutils.RandomDataAndDigest(200) @@ -394,19 +382,15 @@ func TestFindMissingCasBlobsWithProxyFailFastMaxProxyBlobSize(t *testing.T) { t.Fatal(err) } - // Explicitly avoid using WithProxyBackEnd, as we want to control the workers. + // Explicitly avoid using WithProxyBackEnd, as we want to control the concurrency limit. testCacheI, err := New(cacheDir, 10*1024, WithAccessLogger(testutils.NewSilentLogger()), WithProxyMaxBlobSize(300)) if err != nil { t.Fatal(err) } actualDiskCache := testCacheI.(*diskCache) actualDiskCache.proxy = proxy - actualDiskCache.containsQueue = make(chan proxyCheck, 4) - defer func() { - close(actualDiskCache.containsQueue) - }() - // Spawn a single worker. - go actualDiskCache.containsWorker() + // Allow only a single Contains check in flight at a time. + actualDiskCache.containsSem = semaphore.NewWeighted(1) data1, digest1 := testutils.RandomDataAndDigest(100) data2, digest2 := testutils.RandomDataAndDigest(200) @@ -489,3 +473,234 @@ func TestFindMissingCasBlobsWithProxyMaxProxyBlobSize(t *testing.T) { t.Fatalf("Expected missing[0] == digest2, got %+v", missing[0]) } } + +// gatedProxy reports every blob as missing, blocking each Contains call until +// a token is sent on gate. It tracks the high-water mark of concurrent calls. +type gatedProxy struct { + gate chan struct{} + inflight atomic.Int32 + high atomic.Int32 +} + +func (p *gatedProxy) Put(ctx context.Context, kind cache.EntryKind, hash string, logicalSize int64, sizeOnDisk int64, rc io.ReadCloser) { +} + +func (p *gatedProxy) Get(ctx context.Context, kind cache.EntryKind, hash string, _ int64) (io.ReadCloser, int64, error) { + return nil, -1, nil +} + +func (p *gatedProxy) Contains(ctx context.Context, kind cache.EntryKind, hash string, _ int64) (bool, int64) { + n := p.inflight.Add(1) + defer p.inflight.Add(-1) + for { + h := p.high.Load() + if n <= h || p.high.CompareAndSwap(h, n) { + break + } + } + <-p.gate + return false, -1 +} + +func waitForInflight(t *testing.T, p *gatedProxy, want int32) { + t.Helper() + deadline := time.Now().Add(5 * time.Second) + for p.inflight.Load() != want { + if time.Now().After(deadline) { + t.Fatalf("timed out waiting for %d in-flight Contains calls, have %d", + want, p.inflight.Load()) + } + time.Sleep(time.Millisecond) + } +} + +func TestFindMissingCasBlobsConcurrencyLimit(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cacheDir := tempDir(t) + defer os.RemoveAll(cacheDir) + + testCacheI, err := New(cacheDir, 10*1024, WithAccessLogger(testutils.NewSilentLogger())) + if err != nil { + t.Fatal(err) + } + c := testCacheI.(*diskCache) + + const limit = 4 + const numBlobs = 32 + + proxy := &gatedProxy{gate: make(chan struct{})} + c.proxy = proxy + c.containsSem = semaphore.NewWeighted(limit) + + blobs := make([]*pb.Digest, numBlobs) + for i := range blobs { + _, digest := testutils.RandomDataAndDigest(int64(100 + i)) + blobs[i] = &digest + } + + done := make(chan error, 1) + go func() { + done <- c.findMissingCasBlobsInternal(ctx, blobs, false) + }() + + // With every Contains call gated, the producer must saturate the + // semaphore and park. + waitForInflight(t, proxy, limit) + + for i := 0; i < numBlobs; i++ { + proxy.gate <- struct{}{} + } + + if err := <-done; err != nil { + t.Fatal(err) + } + + if got := proxy.high.Load(); got != limit { + t.Errorf("Expected high-water mark of concurrent Contains calls to be %d, got %d", limit, got) + } + + if got := len(filterNonNil(blobs)); got != numBlobs { + t.Errorf("Expected all %d blobs to be reported missing, got %d", numBlobs, got) + } +} + +// slowMissProxy reports every blob as missing after a fixed delay. +type slowMissProxy struct { + delay time.Duration +} + +func (p *slowMissProxy) Put(ctx context.Context, kind cache.EntryKind, hash string, logicalSize int64, sizeOnDisk int64, rc io.ReadCloser) { +} + +func (p *slowMissProxy) Get(ctx context.Context, kind cache.EntryKind, hash string, _ int64) (io.ReadCloser, int64, error) { + return nil, -1, nil +} + +func (p *slowMissProxy) Contains(ctx context.Context, kind cache.EntryKind, hash string, _ int64) (bool, int64) { + time.Sleep(p.delay) + return false, -1 +} + +func TestFindMissingCasBlobsFailFastWakesBlockedProducer(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cacheDir := tempDir(t) + defer os.RemoveAll(cacheDir) + + testCacheI, err := New(cacheDir, 10*1024, WithAccessLogger(testutils.NewSilentLogger())) + if err != nil { + t.Fatal(err) + } + c := testCacheI.(*diskCache) + c.proxy = &slowMissProxy{delay: 50 * time.Millisecond} + c.containsSem = semaphore.NewWeighted(1) + + _, digest1 := testutils.RandomDataAndDigest(100) + _, digest2 := testutils.RandomDataAndDigest(200) + _, digest3 := testutils.RandomDataAndDigest(300) + + // With a concurrency limit of 1, the first check occupies the + // semaphore for the proxy delay while the producer parks in Acquire + // for the second. The first check's miss triggers the fail-fast + // cancellation, which must wake the parked producer. + start := time.Now() + err = c.findMissingCasBlobsInternal(ctx, []*pb.Digest{&digest1, &digest2, &digest3}, true) + elapsed := time.Since(start) + + if !errors.Is(err, errMissingBlob) { + t.Fatalf("Expected err to be errMissingBlob, got: %s", err) + } + + if elapsed > 5*time.Second { + t.Fatalf("Expected fail-fast to return promptly, took %s", elapsed) + } +} + +func TestFindMissingCasBlobsCancelWakesBlockedProducer(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + + cacheDir := tempDir(t) + defer os.RemoveAll(cacheDir) + + testCacheI, err := New(cacheDir, 10*1024, WithAccessLogger(testutils.NewSilentLogger())) + if err != nil { + t.Fatal(err) + } + c := testCacheI.(*diskCache) + + proxy := &gatedProxy{gate: make(chan struct{})} + c.proxy = proxy + c.containsSem = semaphore.NewWeighted(1) + // Unblock the in-flight Contains call when the test finishes. + defer close(proxy.gate) + + _, digest1 := testutils.RandomDataAndDigest(100) + _, digest2 := testutils.RandomDataAndDigest(200) + + done := make(chan error, 1) + go func() { + done <- c.findMissingCasBlobsInternal(ctx, []*pb.Digest{&digest1, &digest2}, false) + }() + + // Wait until the first check is in flight, so the producer is parked + // in Acquire for the second, then cancel. + waitForInflight(t, proxy, 1) + cancel() + + if err := <-done; !errors.Is(err, errRequestCancelled) { + t.Fatalf("Expected err to be errRequestCancelled, got: %s", err) + } +} + +func TestFindMissingCasBlobsConcurrentCallers(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cacheDir := tempDir(t) + defer os.RemoveAll(cacheDir) + + testCacheI, err := New(cacheDir, 10*1024, WithAccessLogger(testutils.NewSilentLogger())) + if err != nil { + t.Fatal(err) + } + c := testCacheI.(*diskCache) + c.proxy = &testCWProxy{blob: "no-such-blob"} + c.containsSem = semaphore.NewWeighted(2) + + const numCallers = 16 + const blobsPerCaller = 200 + + var wg sync.WaitGroup + errs := make(chan error, numCallers) + for i := 0; i < numCallers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + blobs := make([]*pb.Digest, blobsPerCaller) + for j := range blobs { + _, digest := testutils.RandomDataAndDigest(int64(100 + j)) + blobs[j] = &digest + } + missing, err := c.FindMissingCasBlobs(ctx, blobs) + if err != nil { + errs <- err + return + } + if len(missing) != blobsPerCaller { + errs <- fmt.Errorf("expected %d missing blobs, got %d", blobsPerCaller, len(missing)) + } + }() + } + wg.Wait() + close(errs) + for err := range errs { + t.Error(err) + } +} diff --git a/cache/disk/options.go b/cache/disk/options.go index 9d28656..7e66e28 100644 --- a/cache/disk/options.go +++ b/cache/disk/options.go @@ -59,7 +59,7 @@ func WithProxyBackend(proxy cache.Proxy) Option { if proxy != nil { c.diskCache.proxy = proxy - c.diskCache.spawnContainsQueueWorkers() + c.diskCache.initContainsCheckLimiter() } return nil diff --git a/cache/grpcproxy/grpcproxy_test.go b/cache/grpcproxy/grpcproxy_test.go index 7a1abc8..ba82977 100644 --- a/cache/grpcproxy/grpcproxy_test.go +++ b/cache/grpcproxy/grpcproxy_test.go @@ -244,7 +244,7 @@ func newFixture(t *testing.T, proxy cache.Proxy, storageMode string) *fixture { go func() { err := server.ServeGRPC(listener, grpcServer, false, false, true, diskCache, logger, logger) if err != nil { - logger.Printf(err.Error()) + logger.Print(err.Error()) } }()