Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,21 @@ jobs:
virtualenvs-create: true
virtualenvs-in-project: true

- name: Set cache key date
id: cache-date
run: echo "CACHE_DATE=$(date +%Y-%U)" >> $GITHUB_ENV

- name: Load cached venv
id: cached-poetry-dependencies
uses: actions/cache@v3
with:
path: .venv
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ env.CACHE_DATE }}-${{ hashFiles('**/poetry.lock') }}


- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction --no-root
run: poetry install --no-interaction --no-root --with dev

- name: Run lint
run: |
Expand Down
21 changes: 11 additions & 10 deletions fast_cache_middleware/controller.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import http
import logging
import typing as tp
import re
from hashlib import blake2b
from typing import Optional

from starlette.requests import Request
from starlette.responses import Response

from .depends import CacheConfig, CacheDropConfig
from .schemas import CacheConfiguration
from .storages import BaseStorage

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -136,7 +137,7 @@ async def is_cachable_response(self, response: Response) -> bool:
return True

async def generate_cache_key(
self, request: Request, cache_config: CacheConfig
self, request: Request, cache_configuration: CacheConfiguration
) -> str:
"""Generates cache key for request.

Expand All @@ -148,8 +149,8 @@ async def generate_cache_key(
str: Cache key
"""
# Use custom key generation function if available
if cache_config.key_func:
return cache_config.key_func(request)
if cache_configuration.key_func:
return cache_configuration.key_func(request)

# Use standard function
return generate_key(request)
Expand All @@ -160,7 +161,7 @@ async def cache_response(
request: Request,
response: Response,
storage: BaseStorage,
ttl: tp.Optional[int] = None,
ttl: Optional[int] = None,
) -> None:
"""Saves response to cache.

Expand All @@ -180,7 +181,7 @@ async def cache_response(

async def get_cached_response(
self, cache_key: str, storage: BaseStorage
) -> tp.Optional[Response]:
) -> Optional[Response]:
"""Gets cached response if it exists and is valid.

Args:
Expand All @@ -198,13 +199,13 @@ async def get_cached_response(

async def invalidate_cache(
self,
cache_drop_config: CacheDropConfig,
invalidate_paths: list[re.Pattern],
storage: BaseStorage,
) -> None:
"""Invalidates cache by configuration.

Args:
cache_drop_config: Cache invalidation configuration
invalidate_paths: List of regex patterns for cache invalidation
storage: Cache storage

TODO: Comments on improvements:
Expand All @@ -226,6 +227,6 @@ async def invalidate_cache(
5. Add tag support for grouping related caches
and their joint invalidation
"""
for path in cache_drop_config.paths:
for path in invalidate_paths:
await storage.remove(path)
logger.info("Invalidated cache for pattern: %s", path.pattern)
4 changes: 2 additions & 2 deletions fast_cache_middleware/depends.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
import typing as tp
from typing import Callable, Optional

from fastapi import params
from starlette.requests import Request
Expand Down Expand Up @@ -29,7 +29,7 @@ class CacheConfig(BaseCacheConfigDepends):
def __init__(
self,
max_age: int = 5 * 60,
key_func: tp.Optional[tp.Callable[[Request], str]] = None,
key_func: Optional[Callable[[Request], str]] = None,
) -> None:
self.max_age = max_age
self.key_func = key_func
Expand Down
31 changes: 21 additions & 10 deletions fast_cache_middleware/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ._helpers import set_cache_age_in_openapi_schema
from .controller import Controller
from .depends import BaseCacheConfigDepends, CacheConfig, CacheDropConfig
from .schemas import RouteInfo
from .schemas import CacheConfiguration, RouteInfo
from .storages import BaseStorage, InMemoryStorage

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -247,19 +247,23 @@ async def on_http(self, scope: Scope, receive: Receive, send: Send) -> bool | No
if not route_info:
return None

cache_configuration = route_info.cache_config

# Handle invalidation if specified
if cc := route_info.cache_drop_config:
await self.controller.invalidate_cache(cc, storage=self.storage)
if cache_configuration.invalidate_paths:
await self.controller.invalidate_cache(
cache_configuration.invalidate_paths, storage=self.storage
)

# Handle caching if config exists
cache_config = route_info.cache_config
if not cache_config:
if not cache_configuration.max_age:
return None

if not await self.controller.is_cachable_request(request):
return None

cache_key = await self.controller.generate_cache_key(request, cache_config)
cache_key = await self.controller.generate_cache_key(
request, cache_configuration=cache_configuration
)

cached_response = await self.controller.get_cached_response(
cache_key, self.storage
Expand All @@ -279,7 +283,7 @@ async def on_http(self, scope: Scope, receive: Receive, send: Send) -> bool | No
storage=self.storage,
request=request,
cache_key=cache_key,
ttl=cache_config.max_age,
ttl=cache_configuration.max_age,
)()
return True

Expand All @@ -297,10 +301,17 @@ def _extract_routes_info(self, routes: list[routing.APIRoute]) -> list[RouteInfo
) = self._extract_cache_configs_from_route(route)

if cache_config or cache_drop_config:
cache_configuration = CacheConfiguration(
max_age=cache_config.max_age if cache_config else None,
key_func=cache_config.key_func if cache_config else None,
invalidate_paths=(
cache_drop_config.paths if cache_drop_config else None
),
)

route_info = RouteInfo(
route=route,
cache_config=cache_config,
cache_drop_config=cache_drop_config,
cache_config=cache_configuration,
)
routes_info.append(route_info)

Expand Down
83 changes: 70 additions & 13 deletions fast_cache_middleware/schemas.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,78 @@
import typing as tp
import re
from typing import Any, Callable

from pydantic import (
BaseModel,
ConfigDict,
Field,
computed_field,
field_validator,
model_validator,
)
from starlette.requests import Request
from starlette.routing import Route

from .depends import CacheConfig, CacheDropConfig


class RouteInfo:
class CacheConfiguration(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

max_age: int | None = Field(
default=None,
description="Cache lifetime in seconds. If None, caching is disabled.",
)
key_func: Callable[[Request], str] | None = Field(
default=None,
description="Custom cache key generation function. If None, default key generation is used.",
)
invalidate_paths: list[re.Pattern] | None = Field(
default=None,
description="Paths for cache invalidation (strings or regex patterns). No invalidation if None.",
)

@model_validator(mode="after")
def one_of_field_is_set(self) -> "CacheConfiguration":
if (
self.max_age is None
and self.key_func is None
and self.invalidate_paths is None
):
raise ValueError(
"At least one of max_age, key_func, or invalidate_paths must be set."
)
return self

@field_validator("invalidate_paths")
@classmethod
def compile_paths(cls, item: Any) -> Any:
if item is None:
return None
if isinstance(item, str):
return re.compile(f"^{item}")
if isinstance(item, re.Pattern):
return item
if isinstance(item, list):
return [cls.compile_paths(i) for i in item]
raise ValueError(
"invalidate_paths must be a string, regex pattern, or list of them."
)


class RouteInfo(BaseModel):
"""Route information with cache configuration."""

def __init__(
self,
route: Route,
cache_config: CacheConfig | None = None,
cache_drop_config: CacheDropConfig | None = None,
):
self.route = route
self.cache_config = cache_config
self.cache_drop_config = cache_drop_config
self.path: str = getattr(route, "path")
self.methods: tp.Set[str] = getattr(route, "methods", set())
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

route: Route
cache_config: CacheConfiguration

@computed_field # type: ignore[prop-decorator]
@property
def path(self) -> str:
return getattr(self.route, "path", "")

@computed_field # type: ignore[prop-decorator]
@property
def methods(self) -> set[str]:
return getattr(self.route, "methods", set())
16 changes: 7 additions & 9 deletions fast_cache_middleware/serializers.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
import json
import typing as tp
from typing import Any, Callable, Dict, Optional, Tuple, TypeAlias, Union

from starlette.requests import Request
from starlette.responses import Response

# Define types for metadata and stored response
Metadata: tp.TypeAlias = tp.Dict[str, tp.Any] # todo: make it models
StoredResponse: tp.TypeAlias = tp.Tuple[Response, Request, Metadata]
Metadata: TypeAlias = Dict[str, Any] # todo: make it models
StoredResponse: TypeAlias = Tuple[Response, Request, Metadata]


class BaseSerializer:
def dumps(
self, response: Response, request: Request, metadata: Metadata
) -> tp.Union[str, bytes]:
) -> Union[str, bytes]:
raise NotImplementedError()

def loads(
self, data: tp.Union[str, bytes]
) -> tp.Tuple[Response, Request, Metadata]:
def loads(self, data: Union[str, bytes]) -> Tuple[Response, Request, Metadata]:
raise NotImplementedError()

@property
Expand All @@ -29,7 +27,7 @@ class JSONSerializer(BaseSerializer):
def dumps(self, response: Response, request: Request, metadata: Metadata) -> str:
raise NotImplementedError() # fixme: bad implementation now, maybe async?

def loads(self, data: tp.Union[str, bytes]) -> StoredResponse:
def loads(self, data: Union[str, bytes]) -> StoredResponse:
if isinstance(data, bytes):
data = data.decode()

Expand Down Expand Up @@ -63,7 +61,7 @@ def loads(self, data: tp.Union[str, bytes]) -> StoredResponse:
}

# Create empty receive function
async def receive() -> tp.Dict[str, tp.Any]:
async def receive() -> Dict[str, Any]:
return {"type": "http.request", "body": b""}

request = Request(scope, receive)
Expand Down
18 changes: 9 additions & 9 deletions fast_cache_middleware/storages.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import logging
import re
import time
import typing as tp
from collections import OrderedDict
from typing import Any, Dict, Optional, Tuple, Union

from starlette.requests import Request
from starlette.responses import Response
Expand All @@ -14,7 +14,7 @@
logger = logging.getLogger(__name__)

# Define type for stored response
StoredResponse: TypeAlias = tp.Tuple[Response, Request, Metadata]
StoredResponse: TypeAlias = Tuple[Response, Request, Metadata]


# Define base class for cache storage
Expand All @@ -28,8 +28,8 @@ class BaseStorage:

def __init__(
self,
serializer: tp.Optional[BaseSerializer] = None,
ttl: tp.Optional[tp.Union[int, float]] = None,
serializer: Optional[BaseSerializer] = None,
ttl: Optional[Union[int, float]] = None,
) -> None:
self._serializer = serializer or JSONSerializer()

Expand All @@ -43,7 +43,7 @@ async def store(
) -> None:
raise NotImplementedError()

async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
async def retrieve(self, key: str) -> Optional[StoredResponse]:
raise NotImplementedError()

async def remove(self, path: re.Pattern) -> None:
Expand All @@ -70,8 +70,8 @@ class InMemoryStorage(BaseStorage):
def __init__(
self,
max_size: int = 1000,
serializer: tp.Optional[BaseSerializer] = None,
ttl: tp.Optional[tp.Union[int, float]] = None,
serializer: Optional[BaseSerializer] = None,
ttl: Optional[Union[int, float]] = None,
) -> None:
super().__init__(serializer=serializer, ttl=ttl)

Expand All @@ -87,7 +87,7 @@ def __init__(
# OrderedDict for efficient LRU
self._storage: OrderedDict[str, StoredResponse] = OrderedDict()
# Separate expiry time storage for fast TTL checking
self._expiry_times: tp.Dict[str, float] = {}
self._expiry_times: Dict[str, float] = {}
self._last_expiry_check_time: float = 0
self._expiry_check_interval: float = 60

Expand Down Expand Up @@ -126,7 +126,7 @@ async def store(

self._cleanup_lru_items()

async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
async def retrieve(self, key: str) -> Optional[StoredResponse]:
"""Gets response from cache with lazy TTL checking.

Element moves to the end to update LRU position.
Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading