diff --git a/shabda/cache.py b/shabda/cache.py index 265d3a3..e0b1ee4 100644 --- a/shabda/cache.py +++ b/shabda/cache.py @@ -14,7 +14,10 @@ def save(key, value, ttl=60 * 60 * 24): """Write a key:value pair to cache.""" # ttl is "time to live" of the item in seconds - filepath = CACHE_PATH + key + filepath = os.path.normpath(os.path.join(CACHE_PATH, key)) + # Ensure the normalized filepath is within CACHE_PATH + if not filepath.startswith(os.path.normpath(CACHE_PATH)): + raise Exception("Invalid cache key: path traversal detected") expiry = int(time.time() + ttl) cache = (expiry, value) with open(filepath, "w+b") as file: @@ -23,7 +26,10 @@ def save(key, value, ttl=60 * 60 * 24): def load(key): """Read the value for given key from cache.""" - filepath = CACHE_PATH + key + filepath = os.path.normpath(os.path.join(CACHE_PATH, key)) + # Ensure the normalized filepath is within CACHE_PATH + if not filepath.startswith(os.path.normpath(CACHE_PATH)): + raise Exception("Invalid cache key: path traversal detected") try: with open(filepath, "r+b") as file: expiry, value = pickle.load(file)