diff --git a/drivers/store/redis/store.go b/drivers/store/redis/store.go index 195b0d7..835547d 100644 --- a/drivers/store/redis/store.go +++ b/drivers/store/redis/store.go @@ -37,6 +37,15 @@ if v == false then end local ttl = redis.call("pttl", key) return {tonumber(v), ttl} +` + luaResetScript = ` +local key = KEYS[1] +local ttl = tonumber(ARGV[1]) +redis.call("set", key, 0) +if ttl > 0 then + redis.call("pexpire", key, ttl) +end +return {0, ttl} ` ) @@ -68,6 +77,8 @@ type Store struct { luaIncrSHA string // luaPeekSHA is the SHA of peek and expire key script. luaPeekSHA string + // luaResetSHA is the SHA of reset key script. + luaResetSHA string } // NewStore returns an instance of redis store with defaults. @@ -126,16 +137,8 @@ func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (li // Reset returns the limit for given identifier which is set to zero. func (store *Store) Reset(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { - _, err := store.client.Del(ctx, store.getCacheKey(key)).Result() - if err != nil { - return limiter.Context{}, err - } - - count := int64(0) - now := time.Now() - expiration := now.Add(rate.Period) - - return common.GetContextFromState(now, rate, expiration, count), nil + cmd := store.evalSHA(ctx, store.getLuaResetSHA, []string{store.getCacheKey(key)}, rate.Period.Milliseconds()) + return currentContext(cmd, rate) } // getCacheKey returns the full path for an identifier. @@ -147,7 +150,7 @@ func (store *Store) getCacheKey(key string) string { return buffer.String() } -// preloadLuaScripts preloads the "incr" and "peek" lua scripts. +// preloadLuaScripts preloads the "incr", "peek" and "reset" lua scripts. func (store *Store) preloadLuaScripts(ctx context.Context) error { // Verify if we need to load lua scripts. // Inspired by sync.Once. @@ -165,7 +168,7 @@ func (store *Store) reloadLuaScripts(ctx context.Context) error { return store.loadLuaScripts(ctx) } -// loadLuaScripts load "incr" and "peek" lua scripts. +// loadLuaScripts load "incr", "peek" and "reset" lua scripts. // WARNING: Please use preloadLuaScripts or reloadLuaScripts, instead of this one. func (store *Store) loadLuaScripts(ctx context.Context) error { store.luaMutex.Lock() @@ -186,8 +189,14 @@ func (store *Store) loadLuaScripts(ctx context.Context) error { return errors.Wrap(err, `failed to load "peek" lua script`) } + luaResetSHA, err := store.client.ScriptLoad(ctx, luaResetScript).Result() + if err != nil { + return errors.Wrap(err, `failed to load "reset" lua script`) + } + store.luaIncrSHA = luaIncrSHA store.luaPeekSHA = luaPeekSHA + store.luaResetSHA = luaResetSHA atomic.StoreUint32(&store.luaLoaded, 1) @@ -208,6 +217,13 @@ func (store *Store) getLuaPeekSHA() string { return store.luaPeekSHA } +// getLuaResetSHA returns a "thread-safe" value for luaResetSHA. +func (store *Store) getLuaResetSHA() string { + store.luaMutex.RLock() + defer store.luaMutex.RUnlock() + return store.luaResetSHA +} + // evalSHA eval the redis lua sha and load the scripts if missing. func (store *Store) evalSHA(ctx context.Context, getSha func() string, keys []string, args ...interface{}) *libredis.Cmd {