9 changes: 9 additions & 0 deletions docs/api/settings.md
@@ -209,6 +209,14 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:
Each preset also sets a default initial population strategy (see :doc:`../deployment_autotuning` for details).
Users can still override individual ``autotune_*`` settings; explicit values win over the preset. Controlled by ``HELION_AUTOTUNE_EFFORT``.

.. autoattribute:: Settings.autotune_checkpoint_dir

Directory path for saving and resuming autotuning checkpoints. When set, the autotuner
saves in-progress state to ``{dir}/{stable_hash}.pt`` and auto-discovers matching
checkpoints on subsequent runs. The checkpoint file is deleted on successful completion.
Checkpointing is opt-in: when unset (the default), no checkpoints are saved or loaded.
Controlled by ``HELION_AUTOTUNE_CHECKPOINT_DIR``.

.. autoattribute:: Settings.autotune_best_available_max_configs

Maximum number of cached configs to use when seeding the initial population with the ``from_best_available`` strategy.
@@ -323,6 +331,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"LFBOTreeSearch"`` (default),
| ``HELION_AUTOTUNE_PROGRESS_BAR`` | ``autotune_progress_bar`` | Enable or disable the progress bar UI during autotuning. |
| ``HELION_AUTOTUNE_IGNORE_ERRORS`` | ``autotune_ignore_errors`` | Continue autotuning even when recoverable runtime errors occur. |
| ``HELION_AUTOTUNE_CONFIG_OVERRIDES`` | ``autotune_config_overrides`` | Supply JSON forcing particular autotuner config key/value pairs. |
| ``HELION_AUTOTUNE_CHECKPOINT_DIR`` | ``autotune_checkpoint_dir`` | Directory path for saving/resuming autotuning checkpoints (opt-in). |
| ``TRITON_STORE_BINARY_ONLY`` | Triton (autotuning) | Set to ``1`` during autotuning to skip Triton intermediate IRs, reducing cache size ~40%. Set to ``0`` to retain IRs for debugging. |
| ``HELION_CACHE_DIR`` | ``LocalAutotuneCache`` | Override the on-disk directory used for cached autotuning artifacts. |
| ``HELION_SKIP_CACHE`` | ``LocalAutotuneCache`` | When set to ``1``, skip both reading and writing the autotuning cache entirely. |
23 changes: 23 additions & 0 deletions docs/deployment_autotuning.md
@@ -183,6 +183,29 @@ Related settings for `from_best_available` (see {doc}`api/settings`):
| `autotune_best_available_max_configs` | `HELION_BEST_AVAILABLE_MAX_CONFIGS` | 20 | Maximum cached configs to seed |
| `autotune_best_available_max_cache_scan` | `HELION_BEST_AVAILABLE_MAX_CACHE_SCAN` | 500 | Maximum cache files to scan |

### Checkpointing Long-Running Autotuning

For very long autotuning sessions, you can save and resume state using
checkpoints. This is useful when tuning might be interrupted (e.g., preemptible
instances) or when you want to continue tuning from a previous unfinished run.

Set the `HELION_AUTOTUNE_CHECKPOINT_DIR` environment variable to a directory
path. The autotuner will periodically save checkpoints there, keyed by the
kernel's stable hash. If interrupted, re-run with the same directory to resume
automatically. On successful completion, the checkpoint file is cleaned up.

```bash
# Enable checkpointing to a directory:
HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/helion_checkpoints python run_kernel.py

# If interrupted, just re-run with the same directory to resume:
HELION_AUTOTUNE_CHECKPOINT_DIR=/tmp/helion_checkpoints python run_kernel.py
```

Checkpointing is opt-in: without `HELION_AUTOTUNE_CHECKPOINT_DIR`, no checkpoints
are saved or loaded. Multiple kernels can safely share the same directory, since
each kernel writes to a file named by its unique stable hash.
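The per-kernel filename scheme described above (`{dir}/{stable_hash}.pt`) can be sketched in plain Python. The `checkpoint_path` helper and its truncated SHA-256 hash are illustrative only, not Helion's actual stable-hash implementation:

```python
import hashlib
from pathlib import Path


def checkpoint_path(checkpoint_dir: str, kernel_source: str) -> Path:
    # Illustrative stable hash: derived from content, so re-runs of the same
    # kernel map to the same file while different kernels never collide.
    stable_hash = hashlib.sha256(kernel_source.encode()).hexdigest()[:16]
    return Path(checkpoint_dir) / f"{stable_hash}.pt"


a = checkpoint_path("/tmp/helion_checkpoints", "kernel A source")
b = checkpoint_path("/tmp/helion_checkpoints", "kernel B source")
assert a != b  # distinct kernels -> distinct checkpoint files
assert a == checkpoint_path("/tmp/helion_checkpoints", "kernel A source")
```

Because the path is a pure function of the kernel, an interrupted run and its resumption land on the same file with no coordination between processes.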

## Deploy a Single Config

If one configuration wins for every production call, bake it into the decorator:
22 changes: 22 additions & 0 deletions helion/_testing.py
@@ -10,6 +10,7 @@
import operator
import os
from pathlib import Path
import random
import re
import sys
from typing import TYPE_CHECKING
@@ -19,6 +20,7 @@
from typing import cast
import unittest

import numpy as np
import pytest
import torch
import torch.distributed as dist
@@ -55,6 +57,26 @@
from .runtime.kernel import Kernel


def seed_rng(seed: int) -> None:
    """Seed the Python, NumPy, and PyTorch global RNGs with the same value."""
    random.seed(seed)
    np.random.seed(seed)  # noqa: NPY002
    torch.manual_seed(seed)


@contextlib.contextmanager
def fork_rng() -> Generator[None, None, None]:
"""Context manager that forks all RNGs and restores original state on exit."""
python_state = random.getstate()
numpy_state = np.random.get_state() # noqa: NPY002

with torch.random.fork_rng():
try:
yield
finally:
random.setstate(python_state)
np.random.set_state(numpy_state) # noqa: NPY002

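The capture/restore pattern used by `fork_rng` above can be demonstrated with only the built-in `random` module. The `fork_rng_demo` helper below is a self-contained illustrative sketch; the real helper additionally forks NumPy and PyTorch state:

```python
import contextlib
import random


@contextlib.contextmanager
def fork_rng_demo():
    # Capture the RNG state, let the body consume draws, then restore.
    state = random.getstate()
    try:
        yield
    finally:
        random.setstate(state)


random.seed(123)
expected = random.random()

random.seed(123)
with fork_rng_demo():
    random.random()  # draws inside the fork do not leak out

assert random.random() == expected  # state was restored on exit
```

The `try`/`finally` guarantees restoration even if the body raises, which is why the real `fork_rng` wraps its `yield` the same way.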

def _strip_launcher_args(value: str) -> str:
strip_pairs = []
if supports_amd_cdna_tunables():