Skip to content

[Autotuner] Adding LLM-guided search#2003

Draft
choijon5 wants to merge 1 commit intomainfrom
choijon5/stack/3
Draft

[Autotuner] Adding LLM-guided search#2003
choijon5 wants to merge 1 commit intomainfrom
choijon5/stack/3

Conversation

@choijon5
Copy link
Copy Markdown
Contributor

@choijon5 choijon5 commented Apr 12, 2026

Stacked PRs:


[Autotuner] Adding LLM-guided search

Prompting an LLM to get the configs:

image

Ln indicates the number of rounds of LLM prompts.
Ln_spd indicates the speedup of nth round of LLM vs LFBO full autotuning.
Ln_t1/LF compares the wall clock time it takes to finish the nth round of LLM vs LFBO full autotuning.

For simple kernels (matmul, layer_norm), LLM is able to one shot the config with on par perf as LFBO full autotuning at a tiny fraction (10s, 5-8%) of LFBO full autotuning time.
For more complex kernels (attention), LLM does not get to on par perf after 3 rounds, even though perf improves after each round. 3 LLM rounds still only take 14% of LFBO full autotuning time.
There are cases (cross_entropy) where perf does not improve after more LLM rounds.

The LLM prompts (and their responses) are shown in the comments below, although the prompts are still changing.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 12, 2026
@choijon5 choijon5 force-pushed the choijon5/stack/3 branch 4 times, most recently from 4b0c90d to 0a73757 Compare April 12, 2026 10:11
@choijon5
Copy link
Copy Markdown
Contributor Author

choijon5 commented Apr 12, 2026

You are an expert GPU kernel autotuner for Helion/Triton kernels.

Use the provided Configuration Space and Default Configuration as the source of truth for:

  • allowed field names and enum values
  • which fields are scalar vs list-valued
  • required list lengths
  • valid ranges and defaults

General heuristics:

  • analyze the kernel source, input tensors, GPU hardware, and config space to infer likely optimization
    traits from the code itself and target hardware; if unsure, stay closer to default.
  • block_sizes and num_warps should be powers of 2 when present.
  • persistent pid_type is often worth trying when total tile count is comparable to or larger than SM count,
    and it may also be required for some kernels.
  • tensor_descriptor is a distinct family from pointer indexing.
  • higher num_stages and multi-buffering are more aggressive and should be used selectively.

Output contract:

  • Return minified JSON on a single line. No markdown, code fences, comments, pretty-printing, or trailing
    commas.
  • Emit exactly one top-level object: {"configs":[...]} and make every config unique.
  • Do not use Python syntax or expressions such as single-quoted strings or list multiplication like
    ["pointer"] * 4.
  • Only specify fields you want to change; unspecified = default.
  • Use only field names and enum values that appear in the config space.
  • For list-valued fields, emit an explicit JSON array with the exact required length shown in the config
    space.
  • Never use a scalar as shorthand for a list-valued field, and never wrap scalar-valued fields in single-
    element lists.
  • If you are unsure about a field's structure, required list length, or allowed values, omit that field
    instead of guessing.
  • Use null not None, true/false not True/False.
  • Return ONLY minified JSON: {"configs":[...]}

User Prompt

Kernel Source Code

def attention(
    q_in: torch.Tensor,
    k_in: torch.Tensor,
    v_in: torch.Tensor,
) -> torch.Tensor:
    """
    Computes scaled dot-product attention.

    Implements the attention mechanism: Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V

    Args:
        q_in: Query tensor of shape [..., seq_len_q, head_dim]
        k_in: Key tensor of shape [..., seq_len_k, head_dim]
        v_in: Value tensor of shape [..., seq_len_k, head_dim]

    Returns:
        Output tensor of shape [..., seq_len_q, head_dim]
    """
    m_dim = q_in.size(-2)
    n_dim = k_in.size(-2)
    assert n_dim == v_in.size(-2)
    head_dim = hl.specialize(q_in.size(-1))
    assert head_dim == k_in.size(-1) == v_in.size(-1)
    q_view = q_in.reshape([-1, m_dim, head_dim])
    v_view = v_in.reshape([-1, n_dim, head_dim])
    k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
    out = torch.empty_like(q_view)
    sm_scale = 1.0 / math.sqrt(head_dim)
    qk_scale = sm_scale * 1.44269504  # 1/log(2)
    for tile_b, tile_m in hl.tile([q_view.size(0), m_dim]):
        m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
        l_i = torch.full_like(m_i, 1.0)
        acc = hl.zeros([tile_b, tile_m, head_dim], dtype=torch.float32)
        q = q_view[tile_b, tile_m, :]
        for tile_n in hl.tile(v_view.size(1)):
            k = k_view[tile_b, :, tile_n]
            qk = torch.bmm(q, k)
            m_ij = torch.maximum(m_i, torch.amax(qk, -1) * qk_scale)
            qk = qk * qk_scale - m_ij[:, :, None]
            p = torch.exp2(qk)
            l_ij = torch.sum(p, -1)
            alpha = torch.exp2(m_i - m_ij)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha[:, :, None]
            v = v_view[tile_b, tile_n, :]
            p = p.to(v.dtype)
            acc = torch.baddbmm(acc, p, v)
            m_i = m_ij
        m_i += torch.log2(l_i)
        acc = acc / l_i[:, :, None]
        out[tile_b, tile_m, :] = acc.to(out.dtype)
    return out.view(q_in.size())

Input Tensors

arg[0]: shape=[2, 32, 1024, 64], dtype=torch.float16
arg[1]: shape=[2, 32, 1024, 64], dtype=torch.float16
arg[2]: shape=[2, 32, 1024, 64], dtype=torch.float16

GPU Hardware

Device: NVIDIA B200
Compute units (SMs): 148
Total memory: 178.4 GB
Max threads per SM: 2048

Configuration Space

block_sizes: [power_of_2(min=1, max=1, default=1), power_of_2(min=8, max=1024, default=16),

power_of_2(min=16, max=1024, default=16)]
loop_orders: list of 1 x permutation(length=2)
l2_groupings: list of 1 x power_of_2(min=1, max=64, default=1)
range_unroll_factors: [integer(min=0, max=4, default=0), integer(min=0, max=4, default=0)]
range_warp_specializes: [enum(None, False, True), enum(None, False, True)]
range_num_stages: [integer(min=0, max=4, default=0), integer(min=0, max=4, default=0)]
range_multi_buffers: [enum(None, False, True), enum(None, False, True)]
range_flattens: [enum(None, False, True), enum(None, False, True)]
num_warps: power_of_2(min=1, max=32, default=4)
num_stages: integer(min=1, max=8, default=1)
indexing: list of 4 x enum('pointer', 'tensor_descriptor')
pid_type: enum('flat', 'persistent_blocked', 'persistent_interleaved')
num_sm_multiplier: power_of_2(min=1, max=128, default=1)
load_eviction_policies: list of 3 x enum('', 'first', 'last')
maxnreg: enum(None, 32, 64, 128, 256)

Default Configuration

{"block_sizes":[1,16,16],"indexing":["pointer","pointer","pointer","pointer"],"l2_groupings":
[1],"load_eviction_policies":["","",""],"loop_orders":
[[0,1]],"num_stages":1,"num_warps":4,"pid_type":"flat","range_flattens":[null,null],"range_multi_buffers":
[null,null],"range_num_stages":[0,0],"range_unroll_factors":[0,0],"range_warp_specializes":[null,null]}

Search Strategy

- Generate up to 15 UNIQUE candidate configs. Fewer is better than invalid JSON.
- First analyze the kernel source, input tensors, GPU hardware, and config space. Infer likely tiling/

reuse, accumulation, memory-vs-compute, and scheduling traits from the code itself and target hardware,
then choose config families accordingly.
- Cover 3 config families with a rough mix of about 40% near-default safe, 40% balanced throughput, and
20% aggressive configs, while keeping most candidates valid and compilable.
- If the kernel structure is unclear, stay closer to default and avoid aggressive coupled changes.
- Keep each config sparse: usually 2-6 changed fields, omit unchanged defaults, and exceed 6 only when
several coupled changes are needed for a distinct family.
- Make block_sizes materially different across configs rather than emitting near-duplicates.
- Do not guess field structure: for list-valued fields, emit an explicit JSON array of the exact required
length; if that length is unclear, omit the field.
- Avoid configs that simultaneously max out several aggressive knobs such as num_warps, num_stages, and
maxnreg when present, unless strongly justified.
- Compile timeout is 15s, so avoid candidates that are likely to compile very slowly.
- If tensor_descriptor is available, treat it as a separate family: include a few configs using it, but
keep some pure pointer configs too.
- Include both flat and persistent scheduling families when plausible; do not put every config on the
same pid_type.

Task

Suggest the first batch of configs. Include both near-default and exploratory candidates. Return minified
JSON only: {"configs":[...]}

@choijon5
Copy link
Copy Markdown
Contributor Author

The LLM response with:

Round 0 Response

{
"configs": [
{
"block_sizes": [1, 32, 64]
},
{
"block_sizes": [1, 64, 64],
"num_warps": 8
},
{
"block_sizes": [1, 64, 128],
"num_warps": 8
},
{
"block_sizes": [1, 128, 64],
"num_warps": 8
},
{
"block_sizes": [1, 64, 64],
"num_warps": 4,
"l2_groupings": [4]
},
{
"block_sizes": [1, 64, 64],
"num_warps": 4,
"l2_groupings": [8]
},
{
"block_sizes": [1, 64, 64],
"pid_type": "persistent_blocked",
"num_sm_multiplier": 2
},
{
"block_sizes": [1, 64, 64],
"pid_type": "persistent_interleaved",
"num_sm_multiplier": 2
},
{
"block_sizes": [1, 64, 64],
"pid_type": "persistent_blocked",
"num_sm_multiplier": 4,
"l2_groupings": [4]
},
{
"block_sizes": [1, 64, 64],
"num_warps": 8,
"num_stages": 2
},
{
"block_sizes": [1, 64, 64],
"num_warps": 8,
"num_stages": 3,
"maxnreg": 128
},
{
"block_sizes": [1, 64, 64],
"indexing": [
"tensor_descriptor",
"tensor_descriptor",
"tensor_descriptor",
"tensor_descriptor"
],
"num_warps": 8
},
{
"block_sizes": [1, 64, 64],
"indexing": [
"tensor_descriptor",
"tensor_descriptor",
"tensor_descriptor",
"tensor_descriptor"
],
"pid_type": "persistent_blocked",
"num_sm_multiplier": 2
},
{
"block_sizes": [1, 64, 64],
"num_warps": 8,
"range_multi_buffers": [true, null],
"range_num_stages": [2, 0]
},
{
"block_sizes": [1, 64, 64],
"num_warps": 8,
"range_warp_specializes": [true, null],
"range_unroll_factors": [2, 0]
}
]
}

@choijon5
Copy link
Copy Markdown
Contributor Author

The next iteration of LLM prompt:

Round 1 Refinement Prompt

Search State

Best so far: 0.0758 ms — {"block_sizes": [1, 64, 64], "l2_groupings": [4], "num_sm_multiplier": 4, "pid_type": "persistent_blocked"}
Margin vs runner-up: 0.0%
Search coverage: 17 successful / 24 total configs
Failed configs so far: 7

Anchor Configs

Anchor 1 (best): 0.0758 ms — {"block_sizes": [1, 64, 64], "l2_groupings": [4], "num_sm_multiplier": 4, "pid_type":

"persistent_blocked"}
Anchor 2 (+0.0%): 0.0758 ms — {"block_sizes": [1, 64, 64], "num_stages": 3, "num_warps": 8}

Results (best first)

#1: 0.0758 ms — {"block_sizes": [1, 64, 64], "l2_groupings": [4], "num_sm_multiplier": 4, "pid_type": "persistent_blocked"}
#2: 0.0758 ms — {"block_sizes": [1, 64, 64], "num_stages": 3, "num_warps": 8}
#3: 0.0778 ms — {"block_sizes": [1, 64, 64], "l2_groupings": [4]}
#4: 0.0778 ms — {"block_sizes": [1, 64, 64], "l2_groupings": [8]}
#5: 0.0819 ms — {"block_sizes": [1, 128, 64], "num_warps": 8}
#6: 0.0881 ms — {"block_sizes": [1, 64, 128], "num_warps": 8}
#7: 0.0964 ms — {"block_sizes": [1, 64, 64], "num_warps": 8}
#8: 0.0964 ms — {"block_sizes": [1, 64, 64], "indexing": ["tensor_descriptor", "tensor_descriptor", "tensor_descriptor",

"tensor_descriptor"], "num_warps": 8}
#9: 0.0964 ms — {"block_sizes": [1, 64, 64], "num_stages": 2, "num_warps": 8}
#10: 0.1065 ms — {"block_sizes": [1, 64, 64], "num_sm_multiplier": 2, "pid_type": "persistent_interleaved"}
#11: 0.1249 ms — {"block_sizes": [1, 64, 64], "num_sm_multiplier": 2, "pid_type": "persistent_blocked"}
#12: 0.1249 ms — {"block_sizes": [1, 64, 64], "indexing": ["tensor_descriptor", "tensor_descriptor", "tensor_descriptor",
"tensor_descriptor"], "num_sm_multiplier": 2, "pid_type": "persistent_blocked"}
(7 configs failed to compile or had errors)

Top Config Patterns

block_sizes: mostly [1,64,64] (also [1,128,64] x1)
l2_groupings: [4] x2, [8] x1
num_sm_multiplier: always 4
num_stages: always 3
num_warps: always 8
pid_type: always "persistent_blocked"

Failed Config Patterns

Counts: error=7
error: {"block_sizes": [1, 64, 16], "indexing": ["pointer", "pointer", "pointer", "tensor_descriptor"], "load_eviction_policies":

["", "last", "last"], "loop_orders": [[1, 0]], "num_stages": 8, "num_warps": 16, "range_flattens": [null, false],
"range_multi_buffers": [null, false], "range_num_stages": [0, 2], "range_warp_specializes": [null, true]}
error: {"block_sizes": [1, 16, 32], "indexing": ["pointer", "pointer", "pointer", "tensor_descriptor"], "l2_groupings": [16],
"load_eviction_policies": ["last", "first", "last"], "maxnreg": 32, "num_sm_multiplier": 16, "num_stages": 3, "num_warps": 8,
"pid_type": "persistent_blocked", "range_flattens": [null, false], "range_multi_buffers": [null, true], "range_num_stages": [4, 0],
"range_unroll_factors": [0, 2], "range_warp_specializes": [true, null]}
error: {"indexing": ["tensor_descriptor", "tensor_descriptor", "pointer", "tensor_descriptor"], "l2_groupings": [8],
"load_eviction_policies": ["last", "last", ""], "loop_orders": [[1, 0]], "maxnreg": 32, "num_sm_multiplier": 64, "num_stages": 7,
"pid_type": "persistent_interleaved", "range_flattens": [null, true], "range_num_stages": [4, 4], "range_unroll_factors": [1, 2],
"range_warp_specializes": [true, null]}
error: {"block_sizes": [1, 32, 128], "indexing": ["tensor_descriptor", "pointer", "tensor_descriptor", "tensor_descriptor"],
"l2_groupings": [4], "load_eviction_policies": ["first", "last", "last"], "loop_orders": [[1, 0]], "num_stages": 8, "range_flattens":
[null, true], "range_multi_buffers": [null, true], "range_unroll_factors": [0, 3], "range_warp_specializes": [null, false]}

Next Step

- About two thirds of configs should be 1-field mutations of Anchor 1.
- Use most of the rest for 1-2 field mutations of Anchor 2.
- Reserve at most a small minority for one clearly different family, not random noise.
- Prefer edits with attributable effects: change block_sizes, num_warps, num_stages, pid_type, indexing, l2_groupings, or maxnreg

instead of rewriting every field.
- Keep each config sparse: usually 1-4 changed fields, and no more than 6 unless absolutely necessary.
- Do not guess shapes: list-valued fields need explicit arrays of the exact required length, and if that length is unclear you should
omit the field.
- Keep compile cost in mind: avoid candidates that are likely to exceed the 15s compile timeout.
- If unsure, return fewer valid configs instead of verbose or malformed JSON.

Task

Suggest up to 15 NEW UNIQUE configs around the anchors above. Avoid the failed patterns above and favor targeted edits with
attributable effects. Return minified JSON only: {"configs":[...]}

choijon5 added a commit that referenced this pull request Apr 13, 2026
stack-info: PR: #2003, branch: choijon5/stack/3
stack-info: PR: #2003, branch: choijon5/stack/3
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant