Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions helion/autotuner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from .external import autotune as autotune
from .finite_search import FiniteSearch as FiniteSearch
from .llm_search import LLMGuidedSearch as LLMGuidedSearch
from .llm_seeded_lfbo import LLMSeededLFBOTreeSearch as LLMSeededLFBOTreeSearch
from .llm_seeded_lfbo import LLMSeededSearch as LLMSeededSearch
from .local_cache import LocalAutotuneCache as LocalAutotuneCache
from .local_cache import StrictLocalAutotuneCache as StrictLocalAutotuneCache
from .pattern_search import InitialPopulationStrategy as InitialPopulationStrategy
Expand All @@ -38,6 +40,8 @@
"LFBOPatternSearch": LFBOPatternSearch,
"LFBOTreeSearch": LFBOTreeSearch,
"LLMGuidedSearch": LLMGuidedSearch,
"LLMSeededSearch": LLMSeededSearch,
"LLMSeededLFBOTreeSearch": LLMSeededLFBOTreeSearch,
"DifferentialEvolutionSearch": DifferentialEvolutionSearch,
"FiniteSearch": FiniteSearch,
"PatternSearch": PatternSearch,
Expand Down
32 changes: 26 additions & 6 deletions helion/autotuner/base_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,7 @@ def __init__(
super().__init__(kernel, args)
self.finishing_rounds = finishing_rounds
self.population: list[PopulationMember] = []
self._best_available_seed_configs: list[Config] = []
self.config_gen: ConfigGeneration = self.config_spec.create_config_generation(
overrides=self.settings.autotune_config_overrides or None,
advanced_controls_files=self.settings.autotune_search_acf or None,
Expand Down Expand Up @@ -856,15 +857,20 @@ def _find_similar_cached_configs(self, max_configs: int) -> list[SavedBestConfig

def _generate_best_available_population_flat(self) -> list[FlatConfig]:
"""
Generate initial population using default config plus cached configs.
Generate initial population using default config, explicit seed configs,
and cached configs.

Always starts with the default configuration, then adds up to
MAX_BEST_AVAILABLE_CONFIGS matching cached configs from previous runs.
No random configs are added. Duplicate configs are discarded.
Explicit seed configs provided by the caller are added ahead of cached
configs and are not suppressed by cache-skip settings. No random configs
are added. Duplicate configs are discarded.

Returns:
A list of unique FlatConfig values for the initial population.
Minimum size is 1 (just default), maximum is 1 + autotune_best_available_max_configs setting.
Minimum size is 1 (just default), plus any valid unique explicit
seed configs and up to autotune_best_available_max_configs cached
configs.
"""
# Always start with the default config
default_flat = self.config_gen.default_flat()
Expand All @@ -873,6 +879,16 @@ def _generate_best_available_population_flat(self) -> list[FlatConfig]:
result: list[FlatConfig] = [default_flat]
self.log("Starting with default config")

for config in self._best_available_seed_configs:
try:
flat = self.config_gen.flatten(config)
transferred_config = self.config_gen.unflatten(flat)
if transferred_config not in seen:
seen.add(transferred_config)
result.append(flat)
except (ValueError, TypeError, KeyError, AssertionError) as e:
self.log(f"Failed to transfer explicit seed config: {e}")

max_configs = self.settings.autotune_best_available_max_configs
cached_entries = self._find_similar_cached_configs(max_configs)

Expand Down Expand Up @@ -905,12 +921,16 @@ def _generate_best_available_population_flat(self) -> list[FlatConfig]:
if duplicates > 0:
self.log.debug(f"Discarded {duplicates} duplicate config(s)")

self.log(
f"Initial population: 1 default + {len(result) - 1} unique cached = {len(result)} total"
)
self.log(f"Initial population: {len(result)} total")

return result

def set_best_available_seed_configs(self, configs: Sequence[Config]) -> None:
    """Store caller-supplied seed configs on this search instance.

    The configs are copied into a fresh list and kept in
    ``_best_available_seed_configs``, replacing whatever was stored
    before. Because a copy is taken, later mutation of the caller's
    container does not change the stored seeds.

    Args:
        configs: Seed configurations to remember, in priority order.
    """
    self._best_available_seed_configs = [*configs]

def parallel_benchmark_population(
self, members: list[PopulationMember], *, desc: str = "Benchmarking"
) -> list[PopulationMember]:
Expand Down
11 changes: 11 additions & 0 deletions helion/autotuner/block_id_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ def _flat_config(
) -> object:
return fn(self._fragment(base))

def _encode_flat_value(self, base: ConfigSpec, value: object) -> object:
"""Encode a normalized Config value into its flat-slot representation.

Most specs store the same value in Config and FlatConfig, so the
default implementation is the identity. ReductionLoopSpec is the
only override today: it normalizes persistent reductions to None in
Config, but FlatConfig stores that choice as an integer sentinel.
"""
del base
return value


_BlockIdItemT = TypeVar("_BlockIdItemT", bound=_BlockIdItem)

Expand Down
13 changes: 11 additions & 2 deletions helion/autotuner/config_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import cast

from .._compat import warps_to_threads
from .block_id_sequence import BlockIdSequence
from .config_fragment import Category
from .config_fragment import ConfigSpecFragment
from .config_fragment import PowerOfTwoFragment
Expand Down Expand Up @@ -117,14 +118,22 @@ def _apply_overrides(self, config: Config) -> Config:
def flatten(self, config: Config) -> FlatConfig:
"""Inverse of unflatten: convert a Config to a FlatConfig."""
result = self.default_flat()
flat_fields = self.config_spec._flat_fields()
for key, (indices, is_sequence) in self._key_to_flat_indices.items():
if key not in config.config:
continue
value = config.config[key]
if is_sequence:
assert isinstance(value, list)
for idx, v in zip(indices, value, strict=True):
result[idx] = v
field = flat_fields[key]
assert isinstance(field, BlockIdSequence)
# Sequence specs can normalize values in Config differently
# from how they are stored in FlatConfig. Only
# ReductionLoopSpec overrides this today, but keep the dispatch
# on the spec so flatten() remains the generic inverse of
# unflatten().
for idx, spec, v in zip(indices, field, value, strict=True):
result[idx] = spec._encode_flat_value(self.config_spec, v)
else:
assert len(indices) == 1
result[indices[0]] = value
Expand Down
30 changes: 24 additions & 6 deletions helion/autotuner/config_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,9 +1027,10 @@ def __init__(
super().__init__([block_id])
self.size_hint = size_hint

def _flat_config(
self, base: ConfigSpec, fn: Callable[[ConfigSpecFragment], object]
) -> int | None:
def _flat_fragment(self, base: ConfigSpec) -> BlockSizeFragment:
# Shared by both directions:
# - unflatten: flat integer -> Config value via _flat_config()
# - flatten: Config value -> flat integer via _encode_flat_value()
low = 8 # TODO(jansel): is smaller needed?
high = next_power_of_2(max(low, self.size_hint))
default = min(high, 4096)
Expand All @@ -1038,16 +1039,33 @@ def _flat_config(
if base.max_reduction_threads is not None:
if self.size_hint > base.max_reduction_threads:
default = min(default, base.max_reduction_threads)
value = fn(BlockSizeFragment(low, high, default))
return BlockSizeFragment(low, high, default)

def _flat_config(
self, base: ConfigSpec, fn: Callable[[ConfigSpecFragment], object]
) -> int | None:
fragment = self._flat_fragment(base)
value = fn(fragment)
assert isinstance(value, int)
if not (low <= value <= high):
if not (fragment.low <= value <= fragment.high):
raise InvalidConfig(
f"Invalid value for reduction loop {low} <= {value} <= {high}"
"Invalid value for reduction loop "
f"{fragment.low} <= {value} <= {fragment.high}"
)
if value >= self.size_hint:
return None # max size becomes persistent reduction
return value

def _encode_flat_value(self, base: ConfigSpec, value: object) -> object:
# None means "persistent reduction" in the normalized Config. In the
# flat search space that same choice is represented by an integer
# sentinel, typically the fragment default such as 1024 for a 1024-wide
# reduction. This is the one non-identity Config <-> FlatConfig
# mapping today.
if value is None:
return self._flat_fragment(base).default()
return value

def _normalize(self, name: str, value: object) -> int | None:
if value is None:
return None
Expand Down
Loading
Loading