same seq#10

Open
pathfinder-pf wants to merge 84 commits into main from feat/same_seq

Conversation

@pathfinder-pf
Collaborator

@pathfinder-pf pathfinder-pf commented Mar 16, 2026

optim
async copy

Summary by CodeRabbit

  • New Features

    • Added a same-sequence chunked forward variant with an optimized VMEM/DMA buffering pipeline.
  • Improvements

    • Reference implementation now returns final per-sequence states by default.
  • Tests

    • Added and updated tests to exercise the new variant, including profiling, repeated-run validation, and deterministic input setups.

whz added 2 commits March 16, 2026 17:00
fix (×7)

optim

fix (×10)

async copy

fix (×3)

use at

fix (×24)
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on optimizing the chunked forward pass for sequences of uniform length within the Pallas framework. The primary goal is to enhance performance by utilizing Pallas-specific features for conditional execution, explicit memory placement, and asynchronous data transfers. A new kernel is introduced that employs double buffering and asynchronous copies to efficiently prefetch input data, aiming to reduce memory bottlenecks and improve overall execution speed for this specific use case.

Highlights

  • Pallas Conditional Execution: Replaced jax.lax.cond with the Pallas-specific @pl.when decorator for conditional execution within the kernel, streamlining control flow.
  • Memory Space Specification: Explicitly assigned input tensors (k, v, h0, gk) to pltpu.VMEM using pl.BlockSpec, optimizing memory placement for Pallas operations.
  • New Optimized Kernel for Same-Length Sequences: Introduced _chunk_fwd_h_kernel_with_same_seq and its JIT-compiled wrapper chunk_fwd_h_kernel_with_same_seq to handle chunked forward passes for sequences of the same length. This kernel leverages asynchronous memory copies and double buffering for k, v, and gk to improve performance.
  • Asynchronous Copy Implementation: Implemented pltpu.make_async_copy with semaphores and scratch buffers within the new kernel to prefetch data for k, v, and gk tiles, reducing memory access latency.
  • Test Integration: Integrated the new chunk_fwd_h_kernel_with_same_seq into the existing test suite, ensuring it is used when cu_seqlens is not provided (indicating same-length sequences).


Changelog
  • src/ops/common/chunk_h.py
    • Replaced lax.cond with @pl.when for conditional state storage and final state writing.
    • Added memory_space=pltpu.VMEM to pl.BlockSpec for input tensors (k, v, h0, gk).
    • Introduced _chunk_fwd_h_kernel_with_same_seq and chunk_fwd_h_kernel_with_same_seq for optimized handling of same-length sequences.
    • Implemented asynchronous copies (pltpu.make_async_copy) with double buffering and semaphores for k, v, and gk within the new kernel.
    • Configured pl.pallas_call for the new kernel with appropriate grid specifications, scratch shapes including DMA semaphores, and compiler parameters.
  • tests/ops/gla/test.py
    • Added a new test file to specifically test and profile the chunk_fwd_h_kernel_with_same_seq function.
  • tests/ops/gla/test_pallas_chunk_fwd_h.py
    • Imported chunk_fwd_h_kernel_with_same_seq.
    • Modified the _run_pallas function to conditionally call chunk_fwd_h_kernel_with_same_seq when cu_seqlens is None, otherwise defaulting to chunk_fwd_h_kernel.
Activity
  • The author, pathfinder-pf, initiated this pull request with the title 'same seq' and description 'optim async copy', indicating an intent to optimize sequence processing using asynchronous copy mechanisms.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces an optimized kernel for sequences of the same length, leveraging asynchronous data copies for performance. The changes are generally in the right direction, but I've found a critical race condition in the new asynchronous copy logic that needs to be addressed. Additionally, a new test file appears to be a temporary profiling script with hardcoded paths and should be removed from this pull request. Please see my detailed comments below.

Comment on lines +259 to +331
copy_k0 = pltpu.make_async_copy(
    k_ref.at[(0, 0, pl.dslice(0, BT), slice(None))],
    k_scratch_ref.at[0],
    local_copy_sem0,
)
copy_k0.start()

copy_v0 = pltpu.make_async_copy(
    v_ref.at[(0, 0, pl.dslice(0, BT), slice(None))],
    v_scratch_ref.at[0],
    local_copy_sem1,
)
copy_v0.start()

if gk_ref is not None:
    copy_gk0 = pltpu.make_async_copy(
        gk_ref.at[(0, 0, pl.dslice(0, BT), slice(None))],
        gk_scratch_ref.at[0],
        local_copy_sem2,
    )
    copy_gk0.start()

def body(i_t, carry):
    b_h = carry
    buf = jnp.mod(i_t, 2)
    next_buf = jnp.mod(i_t + 1, 2)
    copy_k0.wait()
    copy_v0.wait()
    if gk_ref is not None:
        copy_gk0.wait()
    i_s = i_t // NTS

    @pl.when((i_t % NTS) == 0)
    def store_fn():
        h_ref[0, i_s, 0] = b_h

    @pl.when(i_t + 1 < NT)
    def do_prefetch():
        t0 = (i_t + 1) * BT
        pl_dslice = pl.dslice(t0, BT)
        copy_k0 = pltpu.make_async_copy(
            k_ref.at[(0, 0, pl_dslice, slice(None))],
            k_scratch_ref.at[next_buf],
            local_copy_sem0,
        )
        copy_k0.start()
        copy_v0 = pltpu.make_async_copy(
            v_ref.at[(0, 0, pl_dslice, slice(None))],
            v_scratch_ref.at[next_buf],
            local_copy_sem1,
        )
        copy_v0.start()

        if gk_ref is not None:
            copy_gk0 = pltpu.make_async_copy(
                gk_ref.at[(0, 0, pl_dslice, slice(None))],
                gk_scratch_ref.at[next_buf],
                local_copy_sem2,
            )
            copy_gk0.start()

    k_tile = k_scratch_ref[buf]
    v_tile = v_scratch_ref[buf]
    if gk_ref is not None:
        gk_tile = gk_scratch_ref[buf]
        g_last = gk_tile[-1, :]
        decay = jnp.exp(g_last)
        b_h = b_h * decay[:, None]  # [BK, BV] * [BK, 1]
        k_tile = (k_tile * jnp.exp(g_last[None, :] - gk_tile)).astype(gk_tile.dtype)

    b_h = b_h + jax.lax.dot(k_tile.T, v_tile)

    return b_h

b_h = lax.fori_loop(0, NT, body, b_h)
Contributor


critical

There is a critical race condition in the asynchronous data copy logic within the fori_loop. The async copy handles (copy_k0, copy_v0, copy_gk0) are defined outside the loop's body and are reassigned within the do_prefetch function. Due to Python's scoping rules, this creates new local variables within do_prefetch, and the outer handles are never updated. Consequently, the copy_*.wait() calls at the beginning of the loop body always wait on the handle for the very first chunk, not the one prefetched in the previous iteration. This can lead to using data that has not been copied yet.

To fix this, the async copy handles must be passed through the loop's carry state. This ensures that each iteration waits for the correct data and launches the prefetch for the next one. I've provided a suggestion that refactors the loop to correctly manage the handle state using lax.cond for the conditional prefetch.

    copy_k = pltpu.make_async_copy(
        k_ref.at[(0, 0,  pl.dslice(0, BT), slice(None))],
        k_scratch_ref.at[0],
        local_copy_sem0,
    )
    copy_k.start()

    copy_v = pltpu.make_async_copy(
        v_ref.at[(0, 0,  pl.dslice(0, BT), slice(None))],
        v_scratch_ref.at[0],
        local_copy_sem1,
    )
    copy_v.start()

    if gk_ref is not None:
        copy_gk = pltpu.make_async_copy(
            gk_ref.at[(0, 0,  pl.dslice(0, BT), slice(None))],
            gk_scratch_ref.at[0],
            local_copy_sem2,
        )
        copy_gk.start()
    else:
        copy_gk = None

    def body(i_t, carry):
        b_h, copy_k, copy_v, copy_gk = carry
        buf = jnp.mod(i_t, 2)

        copy_k.wait()
        copy_v.wait()
        if gk_ref is not None:
             copy_gk.wait()
        i_s = i_t // NTS
        @pl.when((i_t % NTS) == 0)
        def store_fn():
            h_ref[0, i_s, 0] = b_h
        
        def prefetch_fn(op):
            next_buf = jnp.mod(i_t + 1,  2)
            t0 = (i_t + 1) * BT
            pl_dslice = pl.dslice(t0, BT)
            next_copy_k = pltpu.make_async_copy(
                k_ref.at[(0, 0,  pl_dslice, slice(None))],
                k_scratch_ref.at[next_buf],
                local_copy_sem0,
            )
            next_copy_k.start()
            next_copy_v = pltpu.make_async_copy(
                v_ref.at[(0, 0, pl_dslice, slice(None))],
                v_scratch_ref.at[next_buf],
                local_copy_sem1,
            )
            next_copy_v.start()

            if gk_ref is not None:
                next_copy_gk = pltpu.make_async_copy(
                    gk_ref.at[(0, 0,  pl_dslice, slice(None))],
                    gk_scratch_ref.at[next_buf],
                    local_copy_sem2,
                )
                next_copy_gk.start()
            else:
                next_copy_gk = op[2]
            return next_copy_k, next_copy_v, next_copy_gk

        k_tile = k_scratch_ref[buf]
        v_tile = v_scratch_ref[buf]
        if gk_ref is not None:
            gk_tile = gk_scratch_ref[buf]
            g_last = gk_tile[-1, :]
            decay = jnp.exp(g_last)
            b_h = b_h * decay[:, None]  # [BK, BV] * [BK,1]
            k_tile = (k_tile * jnp.exp(g_last[None, :] - gk_tile)).astype(gk_tile.dtype)

        b_h = b_h + jax.lax.dot(k_tile.T, v_tile)

        next_copies = lax.cond(i_t + 1 < NT, prefetch_fn, lambda x: x, (copy_k, copy_v, copy_gk))
        
        return (b_h, *next_copies)

    b_h, _, _, _ = lax.fori_loop(0, NT, body, (b_h, copy_k, copy_v, copy_gk))
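The scoping behaviour behind this race can be reproduced in plain Python, independent of Pallas: an assignment inside a nested function creates a new local binding, and the outer name is never updated.

```python
handle = "chunk-0"  # stands in for the first async-copy handle

def do_prefetch():
    # This assignment creates a NEW local variable named `handle`;
    # without a `nonlocal`/`global` declaration, the enclosing
    # binding is never updated.
    handle = "chunk-1"
    return handle

do_prefetch()
# The outer name still refers to the first handle -- exactly why the
# kernel's wait() calls kept waiting on the chunk-0 copies.
assert handle == "chunk-0"
```

Threading the handles through the loop carry, as the suggestion above does, sidesteps the issue entirely because the next iteration receives the freshly created handles as values rather than relying on name rebinding.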

Comment on lines +1 to +60
import jax

B, T, H, K = 4, 1024, 4, 128
chunk_size = 64
N = 1
cu = None
if cu is not None:
    N = len(cu) - 1
else:
    N = B
key = jax.random.PRNGKey(1)
k = jax.random.normal(key, (B, T, H, K))
v = jax.random.normal(key, (B, T, H, K))
gk = jax.random.normal(key, (B, T, H, K))
h0 = jax.random.normal(key, (N, H, K, K))

from src.ops.common.chunk_h import chunk_fwd_h_kernel, chunk_fwd_h_kernel_with_same_seq

def _run_pallas(
    k,
    v,
    gk=None,
    h0=None,
    chunk_size=64,
    *,
    cu_seqlens=None,
):
    h, ht = chunk_fwd_h_kernel_with_same_seq(
        k,
        v,
        gk=gk,
        h0=h0,
        chunk_size=chunk_size,
        output_final_state=True,
    )
    return h, ht


pallas_h, pallas_ht = _run_pallas(
    k,
    v,
    gk=gk,
    h0=h0,
    chunk_size=chunk_size,
    cu_seqlens=cu,
)
jax.block_until_ready(pallas_h)
jax.block_until_ready(pallas_ht)
jax.profiler.start_trace("/home/gcpuser/profile")
for i in range(3):
    pallas_h, pallas_ht = _run_pallas(
        k,
        v,
        gk=gk,
        h0=h0,
        chunk_size=chunk_size,
        cu_seqlens=cu,
    )
    jax.block_until_ready(pallas_h)
    jax.block_until_ready(pallas_ht)
jax.profiler.stop_trace()
Contributor


high

This new file appears to be a temporary profiling script rather than a formal test case. It contains several issues:

  • Hardcoded Path: jax.profiler.start_trace("/home/gcpuser/profile") uses a hardcoded, user-specific path that will not work for other developers or in CI environments.
  • Misleading Filename: The name test.py implies it contains unit tests with assertions, but it lacks any.
  • Dead/Unused Code: The cu variable is always None, and the cu_seqlens parameter in _run_pallas is unused.

This type of script is useful for development but should not be committed to the main repository as a test. Please remove this file from the pull request. If profiling capabilities are desired, they should be implemented as a proper benchmark or a separate, clearly-named script outside the test suite.
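If a standalone profiling script is kept, one way to avoid the hardcoded path is to resolve the trace directory from an environment variable with a temp-dir fallback. This is a sketch; `PALLAS_TRACE_DIR` is an invented variable name, not something the repository defines.

```python
import os
import tempfile

def resolve_trace_dir(env_var="PALLAS_TRACE_DIR"):
    """Return a writable directory for jax.profiler traces: an explicit
    override from the environment if set, otherwise a fresh temp dir,
    so the script works for any developer and in CI."""
    override = os.environ.get(env_var)
    if override:
        os.makedirs(override, exist_ok=True)
        return override
    return tempfile.mkdtemp(prefix="pallas-profile-")

trace_dir = resolve_trace_dir()
assert os.path.isdir(trace_dir)
# jax.profiler.start_trace(trace_dir) would go here in the real script.
```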


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/ops/common/chunk_h.py`:
- Around line 390-409: The allocation for number of split states (NS) uses floor
division and can be too small; change the calculation of NS from T // BS to use
ceil(T / BS) (e.g., math.ceil(T / BS) or equivalent integer-safe expression) so
it matches emitted outputs from the kernel (which writes at t_i = 0, NTS, 2*NTS,
...), and update any related buffers/allocations that reference NS (the buffers
used by h_out_scratch_ref, sems, or other arrays) so they are sized by the new
ceil-based NS; also apply the same fix to the other occurrence referenced in the
review (the similar block around the 520-530 region).
- Around line 466-480: The output_ht() callback unconditionally dereferences
ht_ref via ht_ref.at[...] which can be None when output_final_state=False
(out_specs has None); wrap the entire body of output_ht (all calls to
_async_copy inside output_ht) with a guard checking if ht_ref is not None before
performing any _async_copy to ht_ref.at, and apply the same pattern for the
similar optional-output block around lines 564–569 so no _async_copy targets
ht_ref when ht_ref is None.
- Around line 558-591: The DMA staging VMEM buffers (k_scratch, v_scratch,
gk_scratch, h_scratch, o_scratch, h_out_scratch, ht_out_scratch) are hard-coded
to jnp.float32 but must preserve the caller dtype to avoid illegal cross-dtype
transfers; change their allocation to use the originating dtype (e.g. k.dtype,
v.dtype, gk.dtype, h0.dtype, h_ref.dtype/ht_ref.dtype where appropriate) and
then perform an explicit jnp.astype(jnp.float32) inside the kernel before doing
exp/dot, and cast back to the original dtype only when writing h_ref/ht_ref back
to HBM; also confirm/adjust any make_async_copy usage to copy into the VMEM with
the matching dtype or add a comment guarding full-f32-only behavior if
cross-dtype DMA is unsupported.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4d49f045-f974-4fbd-8f5c-07c7a2ddc454

📥 Commits

Reviewing files that changed from the base of the PR and between feca212 and 6598d5a.

📒 Files selected for processing (1)
  • src/ops/common/chunk_h.py

Comment on lines +390 to +409
i_s = t_i // NTS

@pl.when((t_i % NTS) == 0)
def store_fn():
    nonlocal h_o_buf, next_h_o_buf

    @pl.when(h_o_buf > 0)
    def wait_prev():
        nonlocal h_o_buf, next_h_o_buf
        pre_nts = i - NTS
        b_pre_i, h_pre_i, k_pre_i, v_pre_i, t_pre_i = get_index(pre_nts)
        k_pre_slice = pl.ds(k_pre_i * BK, BK)
        v_pre_slice = pl.ds(v_pre_i * BV, BV)
        b_pre_slice = b_part_i * local_B + b_pre_i
        _async_copy(
            h_out_scratch_ref.at[next_h_o_buf],
            h_ref.at[b_pre_slice, t_pre_i // NTS, h_pre_i, k_pre_slice, v_pre_slice],
            sems.at[4, next_h_o_buf],
            True,
        )

    h_out_scratch_ref[h_o_buf] = o_scratch_ref[...]
    _async_copy(h_out_scratch_ref.at[h_o_buf], h_ref.at[b_slice, i_s, h_i, k_slice, v_slice], sems.at[4, h_o_buf])


⚠️ Potential issue | 🔴 Critical

NS must match the number of emitted split states.

This kernel stores at t_i = 0, NTS, 2 * NTS, ... per sequence, which is ceil(T / BS) outputs, but NS = T // BS only allocates floor(T / BS). Any trailing partial split will write h_ref[b_slice, i_s, ...] past the end.

🛡️ Minimal safe fix
     BS = BT if split_size is None else split_size
     assert BS % BT == 0, (
         f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
     )
+    assert T % BS == 0, (
+        f"The sequence length {T} must be divisible by `split_size` {BS} in the same-seq kernel."
+    )
     # N: the actual number of sequences in the batch with either equal or variable lengths
 
     N, NS = (
         B,
         T // BS,

If the trailing partial split is meant to be supported, switch this to math.ceil(T / BS) instead of floor division.

Also applies to: 520-530
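The floor-vs-ceil discrepancy is easy to see with concrete numbers. This standalone check (illustrative values, not the PR's test configuration) counts the stores the loop described above would emit at t_i = 0, NTS, 2*NTS, ... and compares it to both allocations:

```python
import math

T, BS = 1000, 256            # trailing partial split: T % BS != 0
BT = 64
NTS = BS // BT               # chunks per split
NT = math.ceil(T / BT)       # chunks the kernel iterates over

# Split indices at which the kernel stores a state: t_i = 0, NTS, 2*NTS, ...
emitted = len(range(0, NT, NTS))

assert T // BS == 3                  # floor-based allocation (too small)
assert math.ceil(T / BS) == 4        # states actually emitted
assert emitted == math.ceil(T / BS)  # the 4th store would run past h_ref
```

Whenever T % BS == 0 the two expressions agree, which is why the bug only surfaces with a trailing partial split.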

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/ops/common/chunk_h.py` around lines 390 - 409, The allocation for number
of split states (NS) uses floor division and can be too small; change the
calculation of NS from T // BS to use ceil(T / BS) (e.g., math.ceil(T / BS) or
equivalent integer-safe expression) so it matches emitted outputs from the
kernel (which writes at t_i = 0, NTS, 2*NTS, ...), and update any related
buffers/allocations that reference NS (the buffers used by h_out_scratch_ref,
sems, or other arrays) so they are sized by the new ceil-based NS; also apply
the same fix to the other occurrence referenced in the review (the similar block
around the 520-530 region).

Comment on lines +558 to +591
out_shape = [
    jax.ShapeDtypeStruct(
        shape=(N, NS, H, K, V), dtype=k.dtype if not states_in_fp32 else jnp.float32
    )
]
out_specs = [pl.BlockSpec(memory_space=pltpu.HBM)]
if output_final_state:
    out_shape.append(jax.ShapeDtypeStruct(shape=(N, H, K, V), dtype=k.dtype))
    out_specs.append(pl.BlockSpec(memory_space=pltpu.HBM))
else:
    out_shape.append(None)
    out_specs.append(None)

in_specs = [
    pl.BlockSpec(memory_space=pltpu.HBM),
    pl.BlockSpec(memory_space=pltpu.HBM),
]
k_scratch = pltpu.VMEM((2, BT, BK), jnp.float32)
v_scratch = pltpu.VMEM((2, BT, BV), jnp.float32)
o_scratch = pltpu.VMEM((BK, BV), jnp.float32)
h_out_scratch = pltpu.VMEM((2, BK, BV), jnp.float32)
ht_out_scratch = pltpu.VMEM((2, BK, BV), jnp.float32)
scratch_shapes = [k_scratch, v_scratch]
if h0 is not None:
    in_specs.append(pl.BlockSpec(memory_space=pltpu.HBM))
    h_scratch = pltpu.VMEM((2, BK, BV), jnp.float32)
    scratch_shapes.append(h_scratch)
else:
    scratch_shapes.append(None)
    in_specs.append(None)
if gk is not None:
    in_specs.append(pl.BlockSpec(memory_space=pltpu.HBM))
    gk_scratch = pltpu.VMEM((2, BT, BK), jnp.float32)
    scratch_shapes.append(gk_scratch)


⚠️ Potential issue | 🟠 Major


Keep DMA staging buffers in the transferred dtype.

These buffers are all hard-coded to jnp.float32, but the DMAs touch k/v/gk/h0 inputs and h_ref/ht_ref outputs whose dtypes are caller-dependent. Please verify make_async_copy supports cross-dtype transfers; if it does not, this path is only safe for full-f32 tensors. The safer pattern is: DMA in the original ref dtype, then cast to f32 after the wait for exp / dot, and cast back only when staging h_ref / ht_ref.

♻️ Sketch
-    k_scratch = pltpu.VMEM((2, BT, BK), jnp.float32)
-    v_scratch = pltpu.VMEM((2, BT, BV), jnp.float32)
+    k_scratch = pltpu.VMEM((2, BT, BK), k.dtype)
+    v_scratch = pltpu.VMEM((2, BT, BV), v.dtype)
     o_scratch = pltpu.VMEM((BK, BV), jnp.float32)
-    h_out_scratch = pltpu.VMEM((2, BK, BV), jnp.float32)
-    ht_out_scratch = pltpu.VMEM((2, BK, BV), jnp.float32)
+    h_out_scratch = pltpu.VMEM((2, BK, BV), jnp.float32 if states_in_fp32 else k.dtype)
+    ht_out_scratch = pltpu.VMEM((2, BK, BV), k.dtype)
...
-        h_scratch = pltpu.VMEM((2, BK, BV), jnp.float32)
+        h_scratch = pltpu.VMEM((2, BK, BV), h0.dtype)
...
-        gk_scratch = pltpu.VMEM((2, BT, BK), jnp.float32)
+        gk_scratch = pltpu.VMEM((2, BT, BK), gk.dtype)

Then cast the loaded tiles to jnp.float32 inside the kernel before the exp / dot math.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/ops/common/chunk_h.py` around lines 558 - 591, The DMA staging VMEM
buffers (k_scratch, v_scratch, gk_scratch, h_scratch, o_scratch, h_out_scratch,
ht_out_scratch) are hard-coded to jnp.float32 but must preserve the caller dtype
to avoid illegal cross-dtype transfers; change their allocation to use the
originating dtype (e.g. k.dtype, v.dtype, gk.dtype, h0.dtype,
h_ref.dtype/ht_ref.dtype where appropriate) and then perform an explicit
jnp.astype(jnp.float32) inside the kernel before doing exp/dot, and cast back to
the original dtype only when writing h_ref/ht_ref back to HBM; also
confirm/adjust any make_async_copy usage to copy into the VMEM with the matching
dtype or add a comment guarding full-f32-only behavior if cross-dtype DMA is
unsupported.


@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (4)
src/ops/common/chunk_h.py (4)

575-590: ⚠️ Potential issue | 🟠 Major

Keep DMA staging buffers in the transferred dtype.

k_scratch, v_scratch, gk_scratch, h_scratch, h_out_scratch, and ht_out_scratch are all hard-coded to jnp.float32, but the corresponding HBM refs are caller-typed. Unless this path is intentionally f32-only, non-f32 inputs/outputs now rely on cross-dtype DMA. The safer pattern is to stage in the source/destination dtype, cast to float32 only around the math, and cast back only when storing results.

For the current JAX TPU Pallas API, does `pltpu.make_async_copy` require source and destination refs to have the same dtype / element size, or are dtype conversions supported during HBM↔VMEM DMA? If conversions are not supported, what is the recommended pattern for bf16/f16 tiles that need float32 math in VMEM?

Also applies to: 596-599
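The stage-in-native-dtype pattern this comment recommends can be sketched in plain numpy, with float16 standing in for the caller dtype; in the real kernel the casts would wrap the VMEM tiles around the exp/dot math, but the dtype flow is the same:

```python
import numpy as np

def chunk_update(b_h, k_tile_low, v_tile_low):
    """Stage tiles in the caller dtype, cast to float32 only for the
    math, and keep the accumulator in float32 throughout."""
    k32 = k_tile_low.astype(np.float32)  # cast after the (simulated) DMA
    v32 = v_tile_low.astype(np.float32)
    return b_h + k32.T @ v32             # accumulate in float32

rng = np.random.default_rng(0)
k_tile = rng.normal(size=(16, 8)).astype(np.float16)
v_tile = rng.normal(size=(16, 4)).astype(np.float16)
b_h = np.zeros((8, 4), dtype=np.float32)
b_h = chunk_update(b_h, k_tile, v_tile)
assert b_h.dtype == np.float32

# Cast back to the caller dtype only when committing, as the review
# suggests for the h_ref / ht_ref stores:
committed = b_h.astype(np.float16)
assert committed.dtype == np.float16
```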

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/ops/common/chunk_h.py` around lines 575 - 590, The DMA staging VMEM
buffers (k_scratch, v_scratch, gk_scratch, h_scratch, h_out_scratch,
ht_out_scratch) are incorrectly hard-coded to jnp.float32; change their
pltpu.VMEM allocations to use the same dtype as the corresponding HBM refs (the
caller-typed tensors) so DMA transfers don't perform cross-dtype conversions,
then perform explicit casts to float32 only around the compute path (e.g., cast
the loaded VMEM tiles to jnp.float32 before math) and cast results back to the
original dtype before writing to HBM (before pltpu.make_async_copy or final
stores). Locate these buffers and the places that consume/produce them in
chunk_h.py and update the VMEM allocations and the local casts around the
computation/commit steps accordingly.

262-262: ⚠️ Potential issue | 🔴 Critical

Validate the megacore split before computing local_B.

local_B = B // 2 assumes the hard-coded grid=(2,) can partition the batch evenly. Odd B silently drops the tail item, and B == 1 makes this path do no work.

🛡️ Minimal fix
     NT = pl.cdiv(T, BT)
     NTS = BS // BT

+    assert B % 2 == 0 and B >= 2, "B must be even and >= 2 for megacore parallelism."
     local_B = B // 2
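If a hard failure is preferred over the bare `assert` in the diff, the check can be factored into a small helper. This is a sketch of the suggested guard only; `validate_megacore_split` is a hypothetical name, and `B`/`local_B` follow the review comment rather than the actual kernel code.

```python
def validate_megacore_split(B: int, grid: int = 2) -> int:
    """Reject batch sizes the hard-coded grid=(2,) cannot partition evenly."""
    if B < grid or B % grid != 0:
        raise ValueError(
            f"B must be a multiple of {grid} and >= {grid} for megacore "
            f"parallelism, got B={B}"
        )
    return B // grid  # local_B

print(validate_megacore_split(8))  # 4
```

Odd `B` and `B == 1` now fail loudly instead of silently dropping the tail item or doing no work.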
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/ops/common/chunk_h.py` at line 262, Validate the megacore split before
computing local_B: check the batch size B against the hard-coded grid=(2,)
(i.e., ensure B >= 2 and B % 2 == 0) before using local_B = B // 2; if the check
fails, either raise a clear ValueError (e.g., "B must be even and >=2 for
grid=(2,)") or implement an explicit split policy (distribute the extra item or
assign ceil/floor to each half) and update any downstream code that uses
local_B; reference the variables local_B, B and the grid=(2,) assumption when
making the change.

520-530: ⚠️ Potential issue | 🔴 Critical

Reject partial trailing splits, or size NS with ceiling division.

This kernel emits a boundary state at t_i = 0, NTS, 2 * NTS, ..., so each batch item needs ceil(T / BS) slots. NS = T // BS only allocates floor(T / BS), which makes the last store walk past h_ref when T % BS != 0.

🛡️ Minimal safe fix
     assert BS % BT == 0, (
         f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
     )
+    assert T % BS == 0, (
+        f"The sequence length {T} must be divisible by `split_size` {BS} in the same-seq kernel."
+    )
     # N: the actual number of sequences in the batch with either equal or variable lengths

     N, NS = (
         B,
         T // BS,
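The off-by-one is easy to see with concrete numbers. The values below are hypothetical; `T` and `BS` are the sequence length and split size from the review comment.

```python
T, BS = 100, 32  # T % BS != 0: one partial trailing split

floor_slots = T // BS             # allocates 3 slots
ceil_slots = (T + BS - 1) // BS   # the 4 slots the kernel actually writes

print(floor_slots, ceil_slots)  # 3 4
```

With floor division the fourth boundary-state store lands past the end of `h_ref`; either size `NS` with the ceiling form or assert `T % BS == 0` as in the diff above.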
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/ops/common/chunk_h.py` around lines 520 - 530, The code computes NS as
floor division (NS = T // BS) causing out-of-bounds when T % BS != 0; change NS
to ceiling division (NS = (T + BS - 1) // BS) so each batch item gets ceil(T/BS)
slots, and ensure any downstream logic that indexes into h_ref or iterates over
NS uses this new NS; update the assignment where N, NS are set (referencing BT,
BS, N, NS, B, T, h_ref) to use NS = (T + BS - 1) // BS.

466-480: ⚠️ Potential issue | 🔴 Critical

Guard output_ht() when output_final_state=False.

chunk_fwd_h_kernel_with_same_seq() passes None for the second out spec in that mode, but this callback still dereferences ht_ref.at[...] unconditionally. The default path will fail on the last tile.

🛡️ Minimal fix
         `@pl.when`(t_i + 1 == NT)
         def output_ht():
             nonlocal ht_buf, next_ht_buf
+            if ht_ref is None:
+                return
             ht_out_scratch_ref[ht_buf] = o_scratch_ref[...]

Also applies to: 564-569

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/ops/common/chunk_h.py` around lines 466 - 480, The callback output_ht in
chunk_fwd_h_kernel_with_same_seq unconditionally dereferences ht_ref.at[...]
even when output_final_state is False (the second out spec is None), causing
failures; guard the bodies that reference ht_ref (both the main _async_copy
calls and the nested `@pl.when` blocks that use ht_ref.at[...] and
ht_out_scratch_ref.at[...]) with a condition based on output_final_state (or
skip registering the output_ht callback entirely when output_final_state is
False) so that ht_ref is only accessed when the second output spec is present;
apply the same guard to the analogous callback around lines 564-569.
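The shape of the guard is simple enough to show with a toy stand-in. Everything here is illustrative: `output_ht`, `ht_ref`, and `o_scratch` mirror the names in the review comment, not the real Pallas callback, and a plain list stands in for the HBM ref.

```python
def output_ht(ht_ref, o_scratch):
    """No-op when the final-state ref was not allocated."""
    if ht_ref is None:      # second out spec absent when output_final_state=False
        return
    ht_ref[:] = o_scratch   # commit the final state only when requested

ht = [0.0, 0.0]
output_ht(ht, [1.5, 2.5])    # writes the final state
output_ht(None, [9.9, 9.9])  # silently skipped instead of dereferencing None
print(ht)  # [1.5, 2.5]
```

In the kernel the same early return (or skipping the `@pl.when` registration entirely) keeps the default `output_final_state=False` path from touching `ht_ref.at[...]` on the last tile.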

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 64858d03-944f-47b7-9d9f-84c666256f66

📥 Commits

Reviewing files that changed from the base of the PR and between 6598d5a and 97abcdb.

📒 Files selected for processing (1)
  • src/ops/common/chunk_h.py

