feat: Add CacheProvider API for external distributed caching #12056
Merged

Changes from all commits (35 commits)
All 35 commits are by deepme987:

- 6540aa0 feat: Add CacheProvider API for external distributed caching
- e17571d fix: use deterministic hash for cache keys instead of pickle
- 5e4bbca test: add unit tests for CacheProvider API
- f4623c0 style: remove unused imports in test_cache_provider.py
- 17eed38 fix: move _torch_available before usage and use importlib.util.find_spec
- dcf6868 fix: use hashable types in frozenset test and add dict test
- 9b0ca8b Merge remote-tracking branch 'origin/master' into feat/cache-provider…
- 2049066 refactor: expose CacheProvider API via comfy_api.latest.Caching
- d755f7c docs: clarify should_cache filtering criteria
- 4afa80d docs: make should_cache docstring implementation-agnostic
- 0440ebc feat: add optional ui field to CacheValue
- 0141af0 refactor: rename _is_cacheable_value to _is_external_cacheable_value
- 4cbe4fe refactor: async CacheProvider API + reduce public surface
- da51486 fix: remove unused imports (ruff) and update tests for internal API
- 04097e6 fix: address coderabbit review feedback
- f5c5ff5 Merge remote-tracking branch 'origin/master' into feat/cache-provider…
- c50f02c fix: use _-prefixed imports in _notify_prompt_lifecycle
- 33a0cc2 fix: add sync get_local/set_local for graph traversal
- 311a2d5 chore: remove cloud-specific language from cache provider API
- 26f34d8 style: align documentation with codebase conventions
- 66ad993 Merge branch 'master' into feat/cache-provider-api
- c73e3c9 fix: add usage example to Caching class, remove pickle fallback
- 8ed3386 refactor: move public types to comfy_api, eager provider snapshot
- 15a23ad fix: generalize self-inequality check, fail-closed canonicalization
- 71fd8f7 fix: suppress ruff F401 for re-exported CacheContext
- 7c3c427 fix: enable external caching for subcache (expanded) nodes
- 9586c79 fix: wrap register/unregister as explicit static methods
- 01705f5 fix: use debug-level logging for provider registration
- 1e971fd fix: follow ProxiedSingleton pattern for Caching class
- 2c34c89 fix: inline registration logic in Caching class
- 476538a fix: single Caching definition inside ComfyAPI_latest
- 832d3ef fix: remove prompt_id from CacheContext, type-safe canonicalization
- 0e912e5 Merge branch 'master' into feat/cache-provider-api
- 3891064 fix: address review feedback on cache provider API
- 3361d70 Merge branch 'master' into feat/cache-provider-api
First added file (+42 lines; apparently the `comfy_api.latest._caching` module, given the re-export in the second file):

```python
from abc import ABC, abstractmethod
from typing import Optional
from dataclasses import dataclass


@dataclass
class CacheContext:
    node_id: str
    class_type: str
    cache_key_hash: str  # SHA256 hex digest


@dataclass
class CacheValue:
    outputs: list
    ui: dict = None


class CacheProvider(ABC):
    """Abstract base class for external cache providers.
    Exceptions from provider methods are caught by the caller and never break execution.
    """

    @abstractmethod
    async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
        """Called on local cache miss. Return CacheValue if found, None otherwise."""
        pass

    @abstractmethod
    async def on_store(self, context: CacheContext, value: CacheValue) -> None:
        """Called after local store. Dispatched via asyncio.create_task."""
        pass

    def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
        """Return False to skip external caching for this node. Default: True."""
        return True

    def on_prompt_start(self, prompt_id: str) -> None:
        pass

    def on_prompt_end(self, prompt_id: str) -> None:
        pass
```
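To illustrate how an extension might plug into this API, here is a self-contained sketch of a trivial provider. The `InMemoryProvider` class and its dict-backed store are invented for illustration, and the `CacheContext`/`CacheValue`/`CacheProvider` declarations below simply mirror the diff so the example runs standalone; a real provider would import them from the `comfy_api` package instead.

```python
# Hedged sketch: a toy in-memory CacheProvider, self-contained for illustration.
# In a real extension, import CacheProvider/CacheContext/CacheValue from comfy_api
# instead of redeclaring these stand-ins.
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional


@dataclass
class CacheContext:
    node_id: str
    class_type: str
    cache_key_hash: str  # SHA256 hex digest


@dataclass
class CacheValue:
    outputs: list
    ui: dict = None


class CacheProvider(ABC):
    @abstractmethod
    async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: ...

    @abstractmethod
    async def on_store(self, context: CacheContext, value: CacheValue) -> None: ...

    def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
        return True


class InMemoryProvider(CacheProvider):
    """Toy provider backed by a dict keyed on the cache-key hash."""

    def __init__(self):
        self._store = {}

    async def on_lookup(self, context):
        # Called on local cache miss; None signals an external miss too.
        return self._store.get(context.cache_key_hash)

    async def on_store(self, context, value):
        # Called after local store; honor should_cache before persisting.
        if self.should_cache(context, value):
            self._store[context.cache_key_hash] = value


async def demo():
    provider = InMemoryProvider()
    ctx = CacheContext(node_id="7", class_type="KSampler", cache_key_hash="ab" * 32)
    assert await provider.on_lookup(ctx) is None  # first lookup misses
    await provider.on_store(ctx, CacheValue(outputs=[[1, 2, 3]]))
    hit = await provider.on_lookup(ctx)
    return hit.outputs


outputs = asyncio.run(demo())
```

A real provider would replace the dict with a network client (Redis, S3, etc.) and rely on the documented guarantee that exceptions it raises are swallowed by the caller.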
Second added file (+138 lines):

```python
from typing import Any, Optional, Tuple, List
import hashlib
import json
import logging
import threading

# Public types — source of truth is comfy_api.latest._caching
from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue  # noqa: F401 (re-exported)

_logger = logging.getLogger(__name__)


_providers: List[CacheProvider] = []
_providers_lock = threading.Lock()
_providers_snapshot: Tuple[CacheProvider, ...] = ()


def register_cache_provider(provider: CacheProvider) -> None:
    """Register an external cache provider. Providers are called in registration order."""
    global _providers_snapshot
    with _providers_lock:
        if provider in _providers:
            _logger.warning(f"Provider {provider.__class__.__name__} already registered")
            return
        _providers.append(provider)
        _providers_snapshot = tuple(_providers)
        _logger.debug(f"Registered cache provider: {provider.__class__.__name__}")


def unregister_cache_provider(provider: CacheProvider) -> None:
    global _providers_snapshot
    with _providers_lock:
        try:
            _providers.remove(provider)
            _providers_snapshot = tuple(_providers)
            _logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}")
        except ValueError:
            _logger.warning(f"Provider {provider.__class__.__name__} was not registered")


def _get_cache_providers() -> Tuple[CacheProvider, ...]:
    return _providers_snapshot


def _has_cache_providers() -> bool:
    return bool(_providers_snapshot)


def _clear_cache_providers() -> None:
    global _providers_snapshot
    with _providers_lock:
        _providers.clear()
        _providers_snapshot = ()


def _canonicalize(obj: Any) -> Any:
    # Convert to canonical JSON-serializable form with deterministic ordering.
    # Frozensets have non-deterministic iteration order between Python sessions.
    # Raises ValueError for non-cacheable types (unhashable, unknown) so that
    # _serialize_cache_key returns None and external caching is skipped.
    if isinstance(obj, frozenset):
        return ("__frozenset__", sorted(
            [_canonicalize(item) for item in obj],
            key=lambda x: json.dumps(x, sort_keys=True)
        ))
    elif isinstance(obj, set):
        return ("__set__", sorted(
            [_canonicalize(item) for item in obj],
            key=lambda x: json.dumps(x, sort_keys=True)
        ))
    elif isinstance(obj, tuple):
        return ("__tuple__", [_canonicalize(item) for item in obj])
    elif isinstance(obj, list):
        return [_canonicalize(item) for item in obj]
    elif isinstance(obj, dict):
        return {"__dict__": sorted(
            [[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()],
            key=lambda x: json.dumps(x, sort_keys=True)
        )}
    elif isinstance(obj, (int, float, str, bool, type(None))):
        return (type(obj).__name__, obj)
    elif isinstance(obj, bytes):
        return ("__bytes__", obj.hex())
    else:
        raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}")


def _serialize_cache_key(cache_key: Any) -> Optional[str]:
    # Returns deterministic SHA256 hex digest, or None on failure.
    # Uses JSON (not pickle) because pickle is non-deterministic across sessions.
    try:
        canonical = _canonicalize(cache_key)
        json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
        return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
    except Exception as e:
        _logger.warning(f"Failed to serialize cache key: {e}")
        return None


def _contains_self_unequal(obj: Any) -> bool:
    # Local cache matches by ==. Values where not (x == x) (NaN, etc.) will
    # never hit locally, but serialized form would match externally. Skip these.
    try:
        if not (obj == obj):
            return True
    except Exception:
        return True
    if isinstance(obj, (frozenset, tuple, list, set)):
        return any(_contains_self_unequal(item) for item in obj)
    if isinstance(obj, dict):
        return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items())
    if hasattr(obj, 'value'):
        return _contains_self_unequal(obj.value)
    return False


def _estimate_value_size(value: CacheValue) -> int:
    try:
        import torch
    except ImportError:
        return 0

    total = 0

    def estimate(obj):
        nonlocal total
        if isinstance(obj, torch.Tensor):
            total += obj.numel() * obj.element_size()
        elif isinstance(obj, dict):
            for v in obj.values():
                estimate(v)
        elif isinstance(obj, (list, tuple)):
            for item in obj:
                estimate(item)

    for output in value.outputs:
        estimate(output)
    return total
```
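The determinism claim behind `_canonicalize` and `_serialize_cache_key` can be checked directly. Below is a trimmed, self-contained re-implementation of that logic covering only the types the demo needs; the names `canonicalize` and `key_hash` are illustrative stand-ins for the private helpers, not part of the actual API.

```python
# Hedged sketch of the deterministic cache-key hashing: set members are sorted
# by their canonical JSON form, so iteration order cannot change the digest.
import hashlib
import json


def canonicalize(obj):
    # Tag each value with its type so e.g. 1 and True hash differently.
    if isinstance(obj, frozenset):
        return ("__frozenset__", sorted(
            (canonicalize(i) for i in obj),
            key=lambda x: json.dumps(x, sort_keys=True)
        ))
    if isinstance(obj, tuple):
        return ("__tuple__", [canonicalize(i) for i in obj])
    if isinstance(obj, (int, float, str, bool, type(None))):
        return (type(obj).__name__, obj)
    raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}")


def key_hash(key):
    # Compact, key-sorted JSON gives a stable byte string to hash.
    canonical = canonicalize(key)
    blob = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
    return hashlib.sha256(blob.encode('utf-8')).hexdigest()


# Two frozensets built in different insertion orders hash identically.
a = key_hash(("KSampler", frozenset(["seed", "steps", "cfg"])))
b = key_hash(("KSampler", frozenset(["cfg", "seed", "steps"])))
assert a == b

# Type tagging keeps 1 and True distinct even though 1 == True in Python.
assert key_hash((1,)) != key_hash((True,))
```

This is why the PR moved away from pickle: pickled bytes can differ between sessions (memoization, protocol details), while type-tagged canonical JSON with sorted keys is reproducible, so an external cache keyed on the digest stays valid across restarts and machines.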