Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
6540aa0
feat: Add CacheProvider API for external distributed caching
deepme987 Jan 19, 2026
e17571d
fix: use deterministic hash for cache keys instead of pickle
deepme987 Jan 24, 2026
5e4bbca
test: add unit tests for CacheProvider API
deepme987 Jan 24, 2026
f4623c0
style: remove unused imports in test_cache_provider.py
deepme987 Jan 24, 2026
17eed38
fix: move _torch_available before usage and use importlib.util.find_spec
deepme987 Jan 24, 2026
dcf6868
fix: use hashable types in frozenset test and add dict test
deepme987 Jan 24, 2026
9b0ca8b
Merge remote-tracking branch 'origin/master' into feat/cache-provider…
deepme987 Jan 28, 2026
2049066
refactor: expose CacheProvider API via comfy_api.latest.Caching
deepme987 Jan 29, 2026
d755f7c
docs: clarify should_cache filtering criteria
deepme987 Jan 29, 2026
4afa80d
docs: make should_cache docstring implementation-agnostic
deepme987 Jan 29, 2026
0440ebc
feat: add optional ui field to CacheValue
deepme987 Jan 29, 2026
0141af0
refactor: rename _is_cacheable_value to _is_external_cacheable_value
deepme987 Jan 29, 2026
4cbe4fe
refactor: async CacheProvider API + reduce public surface
deepme987 Mar 3, 2026
da51486
fix: remove unused imports (ruff) and update tests for internal API
deepme987 Mar 3, 2026
04097e6
fix: address coderabbit review feedback
deepme987 Mar 3, 2026
f5c5ff5
Merge remote-tracking branch 'origin/master' into feat/cache-provider…
deepme987 Mar 3, 2026
c50f02c
fix: use _-prefixed imports in _notify_prompt_lifecycle
deepme987 Mar 3, 2026
33a0cc2
fix: add sync get_local/set_local for graph traversal
deepme987 Mar 4, 2026
311a2d5
chore: remove cloud-specific language from cache provider API
deepme987 Mar 4, 2026
26f34d8
style: align documentation with codebase conventions
deepme987 Mar 4, 2026
66ad993
Merge branch 'master' into feat/cache-provider-api
deepme987 Mar 4, 2026
c73e3c9
fix: add usage example to Caching class, remove pickle fallback
deepme987 Mar 4, 2026
8ed3386
refactor: move public types to comfy_api, eager provider snapshot
deepme987 Mar 4, 2026
15a23ad
fix: generalize self-inequality check, fail-closed canonicalization
deepme987 Mar 4, 2026
71fd8f7
fix: suppress ruff F401 for re-exported CacheContext
deepme987 Mar 4, 2026
7c3c427
fix: enable external caching for subcache (expanded) nodes
deepme987 Mar 4, 2026
9586c79
fix: wrap register/unregister as explicit static methods
deepme987 Mar 4, 2026
01705f5
fix: use debug-level logging for provider registration
deepme987 Mar 9, 2026
1e971fd
fix: follow ProxiedSingleton pattern for Caching class
deepme987 Mar 9, 2026
2c34c89
fix: inline registration logic in Caching class
deepme987 Mar 9, 2026
476538a
fix: single Caching definition inside ComfyAPI_latest
deepme987 Mar 9, 2026
832d3ef
fix: remove prompt_id from CacheContext, type-safe canonicalization
deepme987 Mar 10, 2026
0e912e5
Merge branch 'master' into feat/cache-provider-api
deepme987 Mar 10, 2026
3891064
fix: address review feedback on cache provider API
deepme987 Mar 12, 2026
3361d70
Merge branch 'master' into feat/cache-provider-api
deepme987 Mar 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions comfy_api/latest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(self):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()
self.caching = self.Caching()

class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: io.NodeReplace) -> None:
Expand Down Expand Up @@ -84,6 +85,36 @@ async def set_progress(
image=to_display,
)

class Caching(ProxiedSingleton):
"""
External cache provider API for sharing cached node outputs
across ComfyUI instances.

Example::

from comfy_api.latest import Caching

class MyCacheProvider(Caching.CacheProvider):
async def on_lookup(self, context):
... # check external storage

async def on_store(self, context, value):
... # store to external storage

Caching.register_provider(MyCacheProvider())
"""
from ._caching import CacheProvider, CacheContext, CacheValue

async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
Comment thread
deepme987 marked this conversation as resolved.
"""Register an external cache provider. Providers are called in registration order."""
from comfy_execution.cache_provider import register_cache_provider
register_cache_provider(provider)

async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
"""Unregister a previously registered cache provider."""
from comfy_execution.cache_provider import unregister_cache_provider
unregister_cache_provider(provider)

class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
Expand Down Expand Up @@ -116,6 +147,9 @@ class Types:
VOXEL = VOXEL
File3D = File3D


Caching = ComfyAPI_latest.Caching

ComfyAPI = ComfyAPI_latest

# Create a synchronous version of the API
Expand All @@ -135,6 +169,7 @@ class Types:
"Input",
"InputImpl",
"Types",
"Caching",
"ComfyExtension",
"io",
"IO",
Expand Down
42 changes: 42 additions & 0 deletions comfy_api/latest/_caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
from typing import Optional
from dataclasses import dataclass


@dataclass
class CacheContext:
node_id: str
class_type: str
cache_key_hash: str # SHA256 hex digest


@dataclass
class CacheValue:
outputs: list
ui: dict = None


class CacheProvider(ABC):
"""Abstract base class for external cache providers.
Exceptions from provider methods are caught by the caller and never break execution.
"""

@abstractmethod
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
"""Called on local cache miss. Return CacheValue if found, None otherwise."""
pass

@abstractmethod
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
"""Called after local store. Dispatched via asyncio.create_task."""
pass

def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
"""Return False to skip external caching for this node. Default: True."""
return True

def on_prompt_start(self, prompt_id: str) -> None:
pass

def on_prompt_end(self, prompt_id: str) -> None:
pass
138 changes: 138 additions & 0 deletions comfy_execution/cache_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from typing import Any, Optional, Tuple, List
import hashlib
import json
import logging
import threading

# Public types — source of truth is comfy_api.latest._caching
from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported)

_logger = logging.getLogger(__name__)


_providers: List[CacheProvider] = []
_providers_lock = threading.Lock()
_providers_snapshot: Tuple[CacheProvider, ...] = ()


def register_cache_provider(provider: CacheProvider) -> None:
Comment thread
deepme987 marked this conversation as resolved.
"""Register an external cache provider. Providers are called in registration order."""
global _providers_snapshot
with _providers_lock:
if provider in _providers:
_logger.warning(f"Provider {provider.__class__.__name__} already registered")
return
_providers.append(provider)
_providers_snapshot = tuple(_providers)
_logger.debug(f"Registered cache provider: {provider.__class__.__name__}")


def unregister_cache_provider(provider: CacheProvider) -> None:
global _providers_snapshot
with _providers_lock:
try:
_providers.remove(provider)
_providers_snapshot = tuple(_providers)
_logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}")
except ValueError:
_logger.warning(f"Provider {provider.__class__.__name__} was not registered")


def _get_cache_providers() -> Tuple[CacheProvider, ...]:
return _providers_snapshot


def _has_cache_providers() -> bool:
return bool(_providers_snapshot)


def _clear_cache_providers() -> None:
global _providers_snapshot
with _providers_lock:
_providers.clear()
_providers_snapshot = ()


def _canonicalize(obj: Any) -> Any:
# Convert to canonical JSON-serializable form with deterministic ordering.
# Frozensets have non-deterministic iteration order between Python sessions.
# Raises ValueError for non-cacheable types (Unhashable, unknown) so that
# _serialize_cache_key returns None and external caching is skipped.
if isinstance(obj, frozenset):
return ("__frozenset__", sorted(
[_canonicalize(item) for item in obj],
key=lambda x: json.dumps(x, sort_keys=True)
))
elif isinstance(obj, set):
return ("__set__", sorted(
[_canonicalize(item) for item in obj],
key=lambda x: json.dumps(x, sort_keys=True)
))
elif isinstance(obj, tuple):
return ("__tuple__", [_canonicalize(item) for item in obj])
elif isinstance(obj, list):
return [_canonicalize(item) for item in obj]
elif isinstance(obj, dict):
return {"__dict__": sorted(
[[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()],
key=lambda x: json.dumps(x, sort_keys=True)
)}
elif isinstance(obj, (int, float, str, bool, type(None))):
return (type(obj).__name__, obj)
elif isinstance(obj, bytes):
return ("__bytes__", obj.hex())
else:
raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}")


def _serialize_cache_key(cache_key: Any) -> Optional[str]:
# Returns deterministic SHA256 hex digest, or None on failure.
# Uses JSON (not pickle) because pickle is non-deterministic across sessions.
try:
canonical = _canonicalize(cache_key)
json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
except Exception as e:
_logger.warning(f"Failed to serialize cache key: {e}")
return None


def _contains_self_unequal(obj: Any) -> bool:
# Local cache matches by ==. Values where not (x == x) (NaN, etc.) will
# never hit locally, but serialized form would match externally. Skip these.
try:
if not (obj == obj):
return True
except Exception:
return True
if isinstance(obj, (frozenset, tuple, list, set)):
return any(_contains_self_unequal(item) for item in obj)
if isinstance(obj, dict):
return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items())
if hasattr(obj, 'value'):
return _contains_self_unequal(obj.value)
return False


def _estimate_value_size(value: CacheValue) -> int:
try:
import torch
except ImportError:
return 0

total = 0

def estimate(obj):
nonlocal total
if isinstance(obj, torch.Tensor):
total += obj.numel() * obj.element_size()
elif isinstance(obj, dict):
for v in obj.values():
estimate(v)
elif isinstance(obj, (list, tuple)):
for item in obj:
estimate(item)

for output in value.outputs:
estimate(output)
return total
Loading
Loading