Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion helion/autotuner/base_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import abc
import collections
from collections.abc import Iterator
import contextlib
import dataclasses
import datetime
Expand Down Expand Up @@ -435,6 +436,7 @@ def __init__(self, kernel: _AutotunableKernel, args: Sequence[object]) -> None:
self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None
self._precompile_args_path: str | None = None
self._precompile_result_counter = count()
self._bad_config_strs: set[str] = set()

def _prepare(self) -> None:
"""Some initialization deferred until autotuning actually runs.
Expand Down Expand Up @@ -531,9 +533,50 @@ def _try_load_checkpoint(self) -> bool:
# load_state_dict validates required keys and raises CheckpointError for issues
self.load_state_dict(state)

# Load bad configs (from subprocess crash recovery)
self._load_bad_configs()

self.log(f"Resumed at generation {self._current_generation}")
return True

def _load_bad_configs(self) -> None:
    """Merge crash-recovery bad configs from ``_bad_configs.txt`` into memory.

    When a checkpoint directory is configured, reads the bad-config file
    written by the external crash-recovery loop and unions its entries
    into ``self._bad_config_strs``; logs how many configs will be skipped.
    """
    from .subprocess_runner import load_bad_configs

    ckpt_dir = self.settings.autotune_checkpoint_dir
    if ckpt_dir is not None:
        self._bad_config_strs.update(
            load_bad_configs(os.path.join(ckpt_dir, "_bad_configs.txt"))
        )
    if self._bad_config_strs:
        self.log(f"Loaded {len(self._bad_config_strs)} bad config(s) to skip")

@contextlib.contextmanager
def _pending_config(self, config: Config) -> Iterator[None]:
    """Context manager that writes the pending-config breadcrumb on entry
    and removes it on exit.

    If the body raises TritonUnrecoverableRuntimeError the pending file
    is intentionally *not* cleared so the external crash-recovery script
    can detect it.

    NOTE(review): the file is cleared only in the ``else`` branch, so in
    fact *any* exception — not just TritonUnrecoverableRuntimeError —
    leaves the breadcrumb behind; confirm that is intended for
    non-Triton failures (e.g. KeyboardInterrupt).
    """
    from .subprocess_runner import clear_pending, write_pending

    checkpoint_dir_str = self.settings.autotune_checkpoint_dir
    if checkpoint_dir_str is None:
        # Crash recovery is disabled without a checkpoint dir; nothing to
        # write or clean up.
        yield
        return
    write_pending(checkpoint_dir_str, str(config))
    try:
        yield
    except exc.TritonUnrecoverableRuntimeError:
        # Let the pending file survive for the bash crash-recovery script
        raise
    else:
        # Normal completion: remove the breadcrumb so a later unrelated
        # crash does not blame this config.
        clear_pending(checkpoint_dir_str)

def _compute_baseline(
self,
) -> tuple[object, Sequence[int], Sequence[object] | None]:
Expand Down Expand Up @@ -752,6 +795,12 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
Returns:
The performance of the configuration in ms.
"""
# Skip configs that previously crashed the subprocess
config_str = str(config)
if config_str in self._bad_config_strs:
self.log.warning(f"Skipping known-bad config: {config}")
return inf

self._autotune_metrics.num_configs_tested += 1
self.counters["benchmark"] += 1
self.log.debug(lambda: f"Running benchmark for {config!r}")
Expand Down Expand Up @@ -1089,7 +1138,8 @@ def _benchmark(
)
)
# benchmark one-by-one to avoid noisy results
perf = self.benchmark_function(config, fn)
with self._pending_config(config):
perf = self.benchmark_function(config, fn)
status = "ok" if math.isfinite(perf) else "error"
# Log completion after benchmarking
self.log.record_autotune_entry(
Expand Down Expand Up @@ -1193,6 +1243,8 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
exit_stack.callback(self.cleanup)

if not self._try_load_checkpoint():
# Load bad configs even on fresh starts (subprocess recovery)
self._load_bad_configs()
self._init_search()
try:
best = self._autotune()
Expand Down Expand Up @@ -1296,6 +1348,11 @@ def _cleanup_checkpoint(self) -> None:
checkpoint_file.unlink()
self.log(f"Checkpoint cleaned up: {checkpoint_file}")

# Clean up subprocess recovery artifacts
from .subprocess_runner import cleanup_subprocess_artifacts

cleanup_subprocess_artifacts(checkpoint_dir_str)

@staticmethod
def _serialize_numpy_rng_state(
state: tuple[str, Any, int, int, float],
Expand Down
60 changes: 60 additions & 0 deletions helion/autotuner/subprocess_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""File I/O helpers for autotuner crash recovery.

The crash recovery protocol works with an external retry loop
(scripts/autotune_with_crash_recovery.sh). Before benchmarking each
config, the autotuner writes its string representation to a pending
file. If the process crashes (e.g. CUDA illegal memory access), the
pending file survives and the external retry loop records it as a bad
config. On re-run, the autotuner loads the checkpoint + bad configs
and skips the poison config.
"""

from __future__ import annotations

import os
from pathlib import Path

# Filenames of the crash-recovery artifacts inside the checkpoint directory.
_PENDING_FILENAME = "_pending_config.txt"
_BAD_CONFIGS_FILENAME = "_bad_configs.txt"


def write_pending(checkpoint_dir: str, config_str: str) -> None:
    """Write the config being benchmarked to the pending file.

    Args:
        checkpoint_dir: Directory holding the autotuner checkpoint artifacts.
        config_str: String representation of the config about to be run.
    """
    pending_path = Path(checkpoint_dir) / _PENDING_FILENAME
    pending_path.write_text(config_str, encoding="utf-8")


def clear_pending(checkpoint_dir: str) -> None:
    """Remove the pending file after benchmark completes.

    ``missing_ok=True`` makes this a no-op when the file is already gone,
    avoiding the exists()/unlink() race of a check-then-delete pattern.
    """
    (Path(checkpoint_dir) / _PENDING_FILENAME).unlink(missing_ok=True)


def load_bad_configs(bad_configs_path: str) -> set[str]:
    """Load bad config strings from file, one per line.

    Args:
        bad_configs_path: Path to the bad-configs file.

    Returns:
        The set of non-empty, stripped lines; an empty set when the file
        does not exist.
    """
    try:
        # EAFP: read directly and treat a missing file as "no bad configs".
        lines = Path(bad_configs_path).read_text(encoding="utf-8").splitlines()
    except FileNotFoundError:
        return set()
    return {line.strip() for line in lines if line.strip()}


def _append_bad_config(bad_configs_path: str, config_str: str) -> None:
    """Append a bad config string to the bad configs file.

    Flushes and fsyncs so the record survives an imminent process crash.
    """
    with open(bad_configs_path, "a", encoding="utf-8") as f:
        f.write(config_str + "\n")
        f.flush()
        os.fsync(f.fileno())


def cleanup_subprocess_artifacts(checkpoint_dir: str) -> None:
    """Remove crash-recovery files in the checkpoint directory."""
    checkpoint_path = Path(checkpoint_dir)
    for name in (_PENDING_FILENAME, _BAD_CONFIGS_FILENAME):
        # missing_ok=True: deleting an absent artifact is not an error.
        (checkpoint_path / name).unlink(missing_ok=True)
111 changes: 111 additions & 0 deletions scripts/autotune_with_crash_recovery.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
#!/usr/bin/env bash
# Autotuner crash recovery wrapper.
#
# Runs the given command (typically a Python script that triggers helion
# autotuning) in a retry loop. An unrecoverable CUDA crash (illegal memory
# access, misaligned address, ...) leaves a "_pending_config.txt" breadcrumb
# in the checkpoint directory. Each time the wrapped command exits nonzero
# with that breadcrumb present, the offending config is appended to
# "_bad_configs.txt" and the command is re-run; the autotuner resumes from
# its checkpoint and skips every recorded bad config.
#
# Progress detection: a crash should always block a *new* config. If the
# identical config crashes twice in a row, the autotuner is not skipping it
# and we give up rather than loop forever.
#
# Requirements:
#   - HELION_AUTOTUNE_CHECKPOINT_DIR must be set
#
# Usage:
#   HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/ckpt \
#     scripts/autotune_with_crash_recovery.sh -- COMMAND [ARGS...]

# Deliberately no `set -e`: the wrapped command is expected to fail and we
# inspect its exit code by hand.
set -uo pipefail

usage() {
    cat >&2 <<'EOF'
Usage: HELION_AUTOTUNE_CHECKPOINT_DIR=/path/to/dir \
       autotune_with_crash_recovery.sh -- COMMAND [ARGS...]
EOF
    exit "${1:-1}"
}

# --- Argument parsing: everything after `--` is the command to run ---
while [[ $# -gt 0 ]]; do
    case "$1" in
        -h|--help) usage 0 ;;
        --)        shift; break ;;
        *)
            echo "Error: unknown option '$1'" >&2
            usage 1
            ;;
    esac
done

if [[ $# -eq 0 ]]; then
    echo "Error: no command specified after --" >&2
    usage 1
fi

if [[ -z "${HELION_AUTOTUNE_CHECKPOINT_DIR:-}" ]]; then
    echo "Error: HELION_AUTOTUNE_CHECKPOINT_DIR must be set." >&2
    exit 1
fi

# --- Setup ---
ckpt_dir="$HELION_AUTOTUNE_CHECKPOINT_DIR"
mkdir -p "$ckpt_dir"

pending_path="$ckpt_dir/_pending_config.txt"
bad_list_path="$ckpt_dir/_bad_configs.txt"

# --- Retry loop ---
run_count=0
prev_blocked=""

while :; do
    run_count=$((run_count + 1))

    # Run the user command; capture the exit code manually.
    "$@"
    rc=$?

    if [[ $rc -eq 0 ]]; then
        exit 0
    fi

    if [[ ! -f "$pending_path" ]]; then
        # No breadcrumb: not a recoverable CUDA crash. Propagate the
        # original exit code.
        exit "$rc"
    fi

    # Record the poison config and remove the breadcrumb.
    blocked=$(cat "$pending_path")
    rm -f "$pending_path"
    echo "$blocked" >> "$bad_list_path"

    echo "[crash-recovery] Process crashed (exit code $rc, attempt $run_count)." >&2
    echo "[crash-recovery] Blocked config: $blocked" >&2

    # Same config crashing twice in a row means the skip is not working.
    if [[ "$blocked" == "$prev_blocked" ]]; then
        echo "[crash-recovery] Same config crashed twice — the autotuner appears stuck." >&2
        echo "[crash-recovery] All bad configs have been recorded. You can re-run this script and it will resume from the latest checkpoint, skipping all previously recorded bad configs." >&2
        exit 1
    fi
    prev_blocked="$blocked"

    echo "[crash-recovery] Restarting from checkpoint..." >&2
done
67 changes: 67 additions & 0 deletions test/data/autotune_crash_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Helper script for bash crash recovery tests.

Run via:
HELION_AUTOTUNE_CHECKPOINT_DIR=DIR \
scripts/autotune_with_crash_recovery.sh -- python test/data/autotune_crash_helper.py

On first run (when _CRASH_ON_FIRST_BENCHMARK is set and no counter file
exists): patches do_bench to trigger a real CUDA illegal memory access,
which exercises the real _pending_config context manager and
TritonUnrecoverableRuntimeError code path. On subsequent runs: autotuning
resumes from checkpoint normally, skipping the bad config.

Without _CRASH_ON_FIRST_BENCHMARK: runs autotuning normally (used to test
that the bash script passes through a successful run).
"""

from __future__ import annotations

import os
from pathlib import Path

import torch

checkpoint_dir = os.environ["HELION_AUTOTUNE_CHECKPOINT_DIR"]
crash_on_first = os.environ.get("_CRASH_ON_FIRST_BENCHMARK", "")
counter_file = Path(checkpoint_dir) / "_benchmark_counter"

if crash_on_first and not counter_file.exists():
import triton
import triton.language as tl

import helion.autotuner.base_search as _bs

@triton.jit
def _ima_kernel(ptr):
"""Triton kernel that triggers illegal memory access."""
bad_ptr = ptr + (1 << 40)
tl.store(bad_ptr, tl.full([], 42.0, dtype=tl.float32))

_original_do_bench = _bs.do_bench

def _ima_do_bench(*args, **kwargs): # type: ignore[no-untyped-def]
counter_file.write_text("done")
# Restore original so this only fires once
_bs.do_bench = _original_do_bench
# Trigger real CUDA illegal memory access
x = torch.zeros(1, device="cuda")
_ima_kernel[(1,)](x)
torch.cuda.synchronize()
# Should not reach here — IMA raises an exception
return _original_do_bench(*args, **kwargs)

_bs.do_bench = _ima_do_bench

# Import and run real autotuning
from helion._testing import import_path # noqa: E402

datadir = Path(__file__).parent
basic_kernels = import_path(datadir / "basic_kernels.py")

args = (torch.randn([8, 32], device="cuda"), torch.randn([8, 32], device="cuda"))
bound = basic_kernels.add.bind(args)
bound.settings.autotune_checkpoint_dir = checkpoint_dir
bound.settings.autotune_effort = "quick"
config = bound.autotune(args, force=True)
result = bound(*args)
torch.testing.assert_close(result, args[0] + args[1])
Loading
Loading