Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions helion/autotuner/base_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .. import exc
from .._utils import counters
from .base_search import BaseAutotuner
from .base_search import performance

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -171,14 +172,53 @@ def _list_cache_entries(self) -> Sequence[tuple[str, CacheKeyBase]]:
"""Return a sequence of (description, key) tuples for all cache entries."""
raise NotImplementedError

def _handle_early_termination(self) -> Config:
try:
if not hasattr(self.autotuner, "population"):
raise AttributeError("No population available")

# Only consider members that have been benchmarked (have perfs data)
benchmarked = [m for m in self.autotuner.population if m.perfs]

if not benchmarked:
raise ValueError("No benchmarked configurations available")

best_member = min(benchmarked, key=performance)
config = best_member.config

# Log the early termination with details
self.autotuner.log(
f"User-initiated early termination. "
f"Saving best configuration from {len(benchmarked)} benchmarked configs "
f"(perf={best_member.perf:.4f}ms)"
)

return config

except (AttributeError, ValueError, IndexError) as e:
# Something went wrong accessing best config
# This can happen if interrupted before any configs completed
self.autotuner.log(
f"Early termination without valid best config: {e}. "
"No cached result will be saved."
)
raise KeyboardInterrupt from e

def autotune(self, *, skip_cache: bool = False) -> Config:
if skip_cache or os.environ.get("HELION_SKIP_CACHE", "") not in {
"",
"0",
"false",
"False",
}:
return self.autotuner.autotune()
# Add interrupt handling for skip_cache path
try:
return self.autotuner.autotune()
except KeyboardInterrupt:
config = self._handle_early_termination()
self.put(config)
counters["autotune"]["cache_put"] += 1
return config

if (config := self.get()) is not None:
counters["autotune"]["cache_hit"] += 1
Expand Down Expand Up @@ -235,7 +275,10 @@ def autotune(self, *, skip_cache: bool = False) -> Config:

self.autotuner.log("Starting autotuning process, this may take a while...")

config = self.autotuner.autotune()
try:
config = self.autotuner.autotune()
except KeyboardInterrupt:
config = self._handle_early_termination()

self.put(config)
counters["autotune"]["cache_put"] += 1
Expand Down