diff --git a/cache.go b/cache.go
index 69edd02..dea114d 100644
--- a/cache.go
+++ b/cache.go
@@ -86,6 +86,47 @@ func (cache *Cache) Get(key []byte) (value []byte, err error) {
 	return
 }
 
+// MultiGet returns values and errors for the given keys. The returned slices
+// have the same length as keys; values[i] and errs[i] correspond to keys[i].
+// A miss is represented by values[i] == nil and errs[i] == ErrNotFound.
+// MultiGet reduces lock contention by grouping keys by segment and acquiring
+// each segment lock at most once.
+// Note that MultiGet holds each segment lock longer than a single Get (for
+// the duration of all keys in that segment), which can increase Get tail
+// latency when MultiGet and Get run concurrently.
+func (cache *Cache) MultiGet(keys [][]byte) (values [][]byte, errs []error) {
+	n := len(keys)
+	if n == 0 {
+		return nil, nil
+	}
+	values = make([][]byte, n)
+	errs = make([]error, n)
+	type keyLoc struct {
+		idx     int
+		hashVal uint64
+	}
+	var groups [segmentCount][]keyLoc
+	for i, key := range keys {
+		hashVal := hashFunc(key)
+		segID := hashVal & segmentAndOpVal
+		groups[segID] = append(groups[segID], keyLoc{idx: i, hashVal: hashVal})
+	}
+	for segID := 0; segID < segmentCount; segID++ {
+		batch := groups[segID]
+		if len(batch) == 0 {
+			continue
+		}
+		cache.locks[segID].Lock()
+		for _, loc := range batch {
+			value, _, err := cache.segments[segID].get(keys[loc.idx], nil, loc.hashVal, false)
+			values[loc.idx] = value
+			errs[loc.idx] = err
+		}
+		cache.locks[segID].Unlock()
+	}
+	return values, errs
+}
+
 // GetFn is equivalent to Get or GetWithBuf, but it attempts to be zero-copy,
 // calling the provided function with slice view over the current underlying
 // value of the key in memory. The slice is constrained in length and capacity.
diff --git a/cache_test.go b/cache_test.go
index 06119bc..49fb20c 100644
--- a/cache_test.go
+++ b/cache_test.go
@@ -315,6 +315,41 @@ func TestPeekWithExpiration(t *testing.T) {
 	})
 }
 
+func TestMultiGet(t *testing.T) {
+	cache := NewCache(1024)
+	cache.Set([]byte("k1"), []byte("v1"), 0)
+	cache.Set([]byte("k2"), []byte("v2"), 0)
+	cache.Set([]byte("k3"), []byte("v3"), 0)
+
+	keys := [][]byte{[]byte("k1"), []byte("k2"), []byte("missing"), []byte("k3")}
+	values, errs := cache.MultiGet(keys)
+	if len(values) != len(keys) || len(errs) != len(keys) {
+		t.Fatalf("len(values)=%d, len(errs)=%d, want %d", len(values), len(errs), len(keys))
+	}
+	if errs[0] != nil || !bytes.Equal(values[0], []byte("v1")) {
+		t.Errorf("keys[0]: value=%q, err=%v", values[0], errs[0])
+	}
+	if errs[1] != nil || !bytes.Equal(values[1], []byte("v2")) {
+		t.Errorf("keys[1]: value=%q, err=%v", values[1], errs[1])
+	}
+	if errs[2] != ErrNotFound || values[2] != nil {
+		t.Errorf("keys[2] (missing): value=%v, err=%v", values[2], errs[2])
+	}
+	if errs[3] != nil || !bytes.Equal(values[3], []byte("v3")) {
+		t.Errorf("keys[3]: value=%q, err=%v", values[3], errs[3])
+	}
+
+	// Empty keys
+	values, errs = cache.MultiGet(nil)
+	if values != nil || errs != nil {
+		t.Errorf("MultiGet(nil): got values=%v, errs=%v", values, errs)
+	}
+	values, errs = cache.MultiGet([][]byte{})
+	if values != nil || errs != nil {
+		t.Errorf("MultiGet(empty): got values=%v, errs=%v", values, errs)
+	}
+}
+
 func TestGetWithExpirationAndBuf(t *testing.T) {
 	cache := NewCache(1024)
 	key := []byte("abcd")
@@ -1070,6 +1105,67 @@ func BenchmarkCacheGetWithExpiration(b *testing.B) {
 	}
 }
 
+const (
+	benchDataSize  = 10_000 // same value as before; regrouped digits (100_00 read like 100k)
+	benchBatchSize = 100
+)
+
+func BenchmarkParallelCacheGetBatched(b *testing.B) {
+	b.ReportAllocs()
+	b.StopTimer()
+	cache := NewCache(256 * 1024 * 1024)
+	buf := make([]byte, 64)
+	var key [8]byte
+	for i := 0; i < benchDataSize; i++ {
+		binary.LittleEndian.PutUint64(key[:], uint64(i))
+		cache.Set(key[:], buf, 0)
+	}
+	b.StartTimer()
+	b.RunParallel(func(pb *testing.PB) {
+		// Each goroutine uses its own key buffers; writing to the captured
+		// outer key array from every goroutine would be a data race.
+		keys := make([][]byte, benchBatchSize)
+		for i := range keys {
+			keys[i] = make([]byte, 8)
+		}
+		i := 0
+		for pb.Next() {
+			for j := 0; j < benchBatchSize; j++ {
+				binary.LittleEndian.PutUint64(keys[j], uint64((i+j)%benchDataSize))
+				cache.Get(keys[j])
+			}
+			i++
+		}
+	})
+}
+
+func BenchmarkParallelCacheMultiGetBatched(b *testing.B) {
+	b.ReportAllocs()
+	b.StopTimer()
+	cache := NewCache(256 * 1024 * 1024)
+	buf := make([]byte, 64)
+	var key [8]byte
+	for i := 0; i < benchDataSize; i++ {
+		binary.LittleEndian.PutUint64(key[:], uint64(i))
+		cache.Set(key[:], buf, 0)
+	}
+	b.StartTimer()
+	b.RunParallel(func(pb *testing.PB) {
+		keys := make([][]byte, benchBatchSize)
+		for i := range keys {
+			keys[i] = make([]byte, 8)
+		}
+		i := 0
+		for pb.Next() {
+			for j := 0; j < benchBatchSize; j++ {
+				binary.LittleEndian.PutUint64(keys[j], uint64((i+j)%benchDataSize))
+			}
+			cache.MultiGet(keys)
+			i++
+		}
+	})
+}
+
 func BenchmarkMapGet(b *testing.B) {
 	b.StopTimer()
 	m := make(map[string][]byte)