Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request focuses on optimizing the chunked forward pass for sequences of uniform length within the Pallas framework. The primary goal is to enhance performance by using Pallas-specific features for conditional execution, explicit memory placement, and asynchronous data transfers. A new kernel is introduced that employs double buffering and asynchronous copies to prefetch input data, aiming to reduce memory bottlenecks and improve overall execution speed for this use case.
Code Review
This pull request introduces an optimized kernel for sequences of the same length, leveraging asynchronous data copies for performance. The changes are generally in the right direction, but I've found a critical race condition in the new asynchronous copy logic that needs to be addressed. Additionally, a new test file appears to be a temporary profiling script with hardcoded paths and should be removed from this pull request. Please see my detailed comments below.
src/ops/common/chunk_h.py
Outdated
```python
copy_k0 = pltpu.make_async_copy(
    k_ref.at[(0, 0, pl.dslice(0, BT), slice(None))],
    k_scratch_ref.at[0],
    local_copy_sem0,
)
copy_k0.start()

copy_v0 = pltpu.make_async_copy(
    v_ref.at[(0, 0, pl.dslice(0, BT), slice(None))],
    v_scratch_ref.at[0],
    local_copy_sem1,
)
copy_v0.start()

if gk_ref is not None:
    copy_gk0 = pltpu.make_async_copy(
        gk_ref.at[(0, 0, pl.dslice(0, BT), slice(None))],
        gk_scratch_ref.at[0],
        local_copy_sem2,
    )
    copy_gk0.start()

def body(i_t, carry):
    b_h = carry
    buf = jnp.mod(i_t, 2)
    next_buf = jnp.mod(i_t + 1, 2)
    copy_k0.wait()
    copy_v0.wait()
    if gk_ref is not None:
        copy_gk0.wait()
    i_s = i_t // NTS

    @pl.when((i_t % NTS) == 0)
    def store_fn():
        h_ref[0, i_s, 0] = b_h

    @pl.when(i_t + 1 < NT)
    def do_prefetch():
        t0 = (i_t + 1) * BT
        pl_dslice = pl.dslice(t0, BT)
        copy_k0 = pltpu.make_async_copy(
            k_ref.at[(0, 0, pl_dslice, slice(None))],
            k_scratch_ref.at[next_buf],
            local_copy_sem0,
        )
        copy_k0.start()
        copy_v0 = pltpu.make_async_copy(
            v_ref.at[(0, 0, pl_dslice, slice(None))],
            v_scratch_ref.at[next_buf],
            local_copy_sem1,
        )
        copy_v0.start()

        if gk_ref is not None:
            copy_gk0 = pltpu.make_async_copy(
                gk_ref.at[(0, 0, pl_dslice, slice(None))],
                gk_scratch_ref.at[next_buf],
                local_copy_sem2,
            )
            copy_gk0.start()

    k_tile = k_scratch_ref[buf]
    v_tile = v_scratch_ref[buf]
    if gk_ref is not None:
        gk_tile = gk_scratch_ref[buf]
        g_last = gk_tile[-1, :]
        decay = jnp.exp(g_last)
        b_h = b_h * decay[:, None]  # [BK, BV] * [BK, 1]
        k_tile = (k_tile * jnp.exp(g_last[None, :] - gk_tile)).astype(gk_tile.dtype)

    b_h = b_h + jax.lax.dot(k_tile.T, v_tile)

    return b_h

b_h = lax.fori_loop(0, NT, body, b_h)
```
There is a critical race condition in the asynchronous data copy logic within the fori_loop. The async copy handles (copy_k0, copy_v0, copy_gk0) are defined outside the loop's body and are reassigned within the do_prefetch function. Due to Python's scoping rules, this creates new local variables within do_prefetch, and the outer handles are never updated. Consequently, the copy_*.wait() calls at the beginning of the loop body always wait on the handle for the very first chunk, not the one prefetched in the previous iteration. This can lead to using data that has not been copied yet.
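The scoping behavior at the root of this bug can be reproduced in plain Python, independent of Pallas (names below are illustrative, not from the kernel):

```python
def run_loop():
    handle = "copy-for-chunk-0"  # handle created before the loop body

    def body():
        # This assignment creates a NEW local variable inside body();
        # the outer `handle` is never rebound.
        handle = "copy-for-chunk-1"
        return handle

    body()
    return handle  # still "copy-for-chunk-0" — a wait() here targets the first chunk

print(run_loop())  # → copy-for-chunk-0
```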
To fix this, the async copy handles must be passed through the loop's carry state. This ensures that each iteration waits for the correct data and launches the prefetch for the next one. I've provided a suggestion that refactors the loop to correctly manage the handle state using lax.cond for the conditional prefetch.
```python
copy_k = pltpu.make_async_copy(
    k_ref.at[(0, 0, pl.dslice(0, BT), slice(None))],
    k_scratch_ref.at[0],
    local_copy_sem0,
)
copy_k.start()
copy_v = pltpu.make_async_copy(
    v_ref.at[(0, 0, pl.dslice(0, BT), slice(None))],
    v_scratch_ref.at[0],
    local_copy_sem1,
)
copy_v.start()
if gk_ref is not None:
    copy_gk = pltpu.make_async_copy(
        gk_ref.at[(0, 0, pl.dslice(0, BT), slice(None))],
        gk_scratch_ref.at[0],
        local_copy_sem2,
    )
    copy_gk.start()
else:
    copy_gk = None

def body(i_t, carry):
    b_h, copy_k, copy_v, copy_gk = carry
    buf = jnp.mod(i_t, 2)
    copy_k.wait()
    copy_v.wait()
    if gk_ref is not None:
        copy_gk.wait()
    i_s = i_t // NTS

    @pl.when((i_t % NTS) == 0)
    def store_fn():
        h_ref[0, i_s, 0] = b_h

    def prefetch_fn(op):
        next_buf = jnp.mod(i_t + 1, 2)
        t0 = (i_t + 1) * BT
        pl_dslice = pl.dslice(t0, BT)
        next_copy_k = pltpu.make_async_copy(
            k_ref.at[(0, 0, pl_dslice, slice(None))],
            k_scratch_ref.at[next_buf],
            local_copy_sem0,
        )
        next_copy_k.start()
        next_copy_v = pltpu.make_async_copy(
            v_ref.at[(0, 0, pl_dslice, slice(None))],
            v_scratch_ref.at[next_buf],
            local_copy_sem1,
        )
        next_copy_v.start()
        if gk_ref is not None:
            next_copy_gk = pltpu.make_async_copy(
                gk_ref.at[(0, 0, pl_dslice, slice(None))],
                gk_scratch_ref.at[next_buf],
                local_copy_sem2,
            )
            next_copy_gk.start()
        else:
            next_copy_gk = op[2]
        return next_copy_k, next_copy_v, next_copy_gk

    k_tile = k_scratch_ref[buf]
    v_tile = v_scratch_ref[buf]
    if gk_ref is not None:
        gk_tile = gk_scratch_ref[buf]
        g_last = gk_tile[-1, :]
        decay = jnp.exp(g_last)
        b_h = b_h * decay[:, None]  # [BK, BV] * [BK, 1]
        k_tile = (k_tile * jnp.exp(g_last[None, :] - gk_tile)).astype(gk_tile.dtype)
    b_h = b_h + jax.lax.dot(k_tile.T, v_tile)
    next_copies = lax.cond(i_t + 1 < NT, prefetch_fn, lambda x: x, (copy_k, copy_v, copy_gk))
    return (b_h, *next_copies)

b_h, _, _, _ = lax.fori_loop(0, NT, body, (b_h, copy_k, copy_v, copy_gk))
```

test.py

```python
import jax

B, T, H, K = 4, 1024, 4, 128
chunk_size = 64
N = 1
cu = None
if cu is not None:
    N = len(cu) - 1
else:
    N = B
key = jax.random.PRNGKey(1)
k = jax.random.normal(key, (B, T, H, K))
v = jax.random.normal(key, (B, T, H, K))
gk = jax.random.normal(key, (B, T, H, K))
h0 = jax.random.normal(key, (N, H, K, K))

from src.ops.common.chunk_h import chunk_fwd_h_kernel, chunk_fwd_h_kernel_with_same_seq

def _run_pallas(
    k,
    v,
    gk=None,
    h0=None,
    chunk_size=64,
    *,
    cu_seqlens=None,
):
    h, ht = chunk_fwd_h_kernel_with_same_seq(
        k,
        v,
        gk=gk,
        h0=h0,
        chunk_size=chunk_size,
        output_final_state=True,
    )
    return h, ht

pallas_h, pallas_ht = _run_pallas(
    k,
    v,
    gk=gk,
    h0=h0,
    chunk_size=chunk_size,
    cu_seqlens=cu,
)
jax.block_until_ready(pallas_h)
jax.block_until_ready(pallas_ht)
jax.profiler.start_trace("/home/gcpuser/profile")
for i in range(3):
    pallas_h, pallas_ht = _run_pallas(
        k,
        v,
        gk=gk,
        h0=h0,
        chunk_size=chunk_size,
        cu_seqlens=cu,
    )
    jax.block_until_ready(pallas_h)
    jax.block_until_ready(pallas_ht)
jax.profiler.stop_trace()
```
This new file appears to be a temporary profiling script rather than a formal test case. It contains several issues:

- Hardcoded path: `jax.profiler.start_trace("/home/gcpuser/profile")` uses a hardcoded, user-specific path that will not work for other developers or in CI environments.
- Misleading filename: the name `test.py` implies it contains unit tests with assertions, but it lacks any.
- Dead/unused code: the `cu` variable is always `None`, and the `cu_seqlens` parameter in `_run_pallas` is unused.
This type of script is useful for development but should not be committed to the main repository as a test. Please remove this file from the pull request. If profiling capabilities are desired, they should be implemented as a proper benchmark or a separate, clearly-named script outside the test suite.
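If a profiling entry point is kept as a separate, clearly named script, the trace destination can be made portable instead of hardcoded. A minimal sketch — the `PROFILE_DIR` environment-variable name is an assumption, not an existing project convention:

```python
import os
import tempfile

# Resolve a portable trace directory: honor an env override (hypothetical
# PROFILE_DIR name), otherwise fall back to a fresh temp directory.
trace_dir = os.environ.get("PROFILE_DIR") or tempfile.mkdtemp(prefix="pallas_profile_")

# The JAX profiler calls would then use it, e.g.:
#   jax.profiler.start_trace(trace_dir)
#   ...run the kernel under test...
#   jax.profiler.stop_trace()
print(trace_dir)
```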
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/ops/common/chunk_h.py`:
- Around line 390-409: The allocation for number of split states (NS) uses floor
division and can be too small; change the calculation of NS from T // BS to use
ceil(T / BS) (e.g., math.ceil(T / BS) or equivalent integer-safe expression) so
it matches emitted outputs from the kernel (which writes at t_i = 0, NTS, 2*NTS,
...), and update any related buffers/allocations that reference NS (the buffers
used by h_out_scratch_ref, sems, or other arrays) so they are sized by the new
ceil-based NS; also apply the same fix to the other occurrence referenced in the
review (the similar block around the 520-530 region).
- Around line 466-480: The output_ht() callback unconditionally dereferences
ht_ref via ht_ref.at[...] which can be None when output_final_state=False
(out_specs has None); wrap the entire body of output_ht (all calls to
_async_copy inside output_ht) with a guard checking if ht_ref is not None before
performing any _async_copy to ht_ref.at, and apply the same pattern for the
similar optional-output block around lines 564–569 so no _async_copy targets
ht_ref when ht_ref is None.
- Around line 558-591: The DMA staging VMEM buffers (k_scratch, v_scratch,
gk_scratch, h_scratch, o_scratch, h_out_scratch, ht_out_scratch) are hard-coded
to jnp.float32 but must preserve the caller dtype to avoid illegal cross-dtype
transfers; change their allocation to use the originating dtype (e.g. k.dtype,
v.dtype, gk.dtype, h0.dtype, h_ref.dtype/ht_ref.dtype where appropriate) and
then perform an explicit jnp.astype(jnp.float32) inside the kernel before doing
exp/dot, and cast back to the original dtype only when writing h_ref/ht_ref back
to HBM; also confirm/adjust any make_async_copy usage to copy into the VMEM with
the matching dtype or add a comment guarding full-f32-only behavior if
cross-dtype DMA is unsupported.
src/ops/common/chunk_h.py
Outdated
```python
i_s = t_i // NTS

@pl.when((t_i % NTS) == 0)
def store_fn():
    nonlocal h_o_buf, next_h_o_buf

    @pl.when(h_o_buf > 0)
    def wait_prev():
        nonlocal h_o_buf, next_h_o_buf
        pre_nts = i - NTS
        b_pre_i, h_pre_i, k_pre_i, v_pre_i, t_pre_i = get_index(pre_nts)
        k_pre_slice = pl.ds(k_pre_i * BK, BK)
        v_pre_slice = pl.ds(v_pre_i * BV, BV)
        b_pre_slice = b_part_i * local_B + b_pre_i
        _async_copy(
            h_out_scratch_ref.at[next_h_o_buf],
            h_ref.at[b_pre_slice, t_pre_i // NTS, h_pre_i, k_pre_slice, v_pre_slice],
            sems.at[4, next_h_o_buf],
            True,
        )

    h_out_scratch_ref[h_o_buf] = o_scratch_ref[...]
    _async_copy(h_out_scratch_ref.at[h_o_buf], h_ref.at[b_slice, i_s, h_i, k_slice, v_slice], sems.at[4, h_o_buf])
```
NS must match the number of emitted split states.

This kernel stores at `t_i = 0, NTS, 2 * NTS, ...` per sequence, which is `ceil(T / BS)` outputs, but `NS = T // BS` only allocates `floor(T / BS)`. Any trailing partial split will write `h_ref[b_slice, i_s, ...]` past the end.

🛡️ Minimal safe fix

```diff
  BS = BT if split_size is None else split_size
  assert BS % BT == 0, (
      f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
  )
+ assert T % BS == 0, (
+     f"The sequence length {T} must be divisible by `split_size` {BS} in the same-seq kernel."
+ )
  # N: the actual number of sequences in the batch with either equal or variable lengths
  N, NS = (
      B,
      T // BS,
```

If the trailing partial split is meant to be supported, switch this to `math.ceil(T / BS)` instead of floor division.
Also applies to: 520-530
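The off-by-one is easy to check with concrete sizes (the numbers below are hypothetical, not from the kernel):

```python
import math

T, BS = 1000, 256               # hypothetical: T not divisible by the split size BS
ns_alloc = T // BS              # slots allocated today: floor(1000 / 256) = 3
ns_emitted = math.ceil(T / BS)  # boundary states the kernel writes: 4
print(ns_alloc, ns_emitted)     # → 3 4  (the last store lands past the end)
```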
```python
out_shape = [
    jax.ShapeDtypeStruct(
        shape=(N, NS, H, K, V), dtype=k.dtype if not states_in_fp32 else jnp.float32
    )
]
out_specs = [pl.BlockSpec(memory_space=pltpu.HBM)]
if output_final_state:
    out_shape.append(jax.ShapeDtypeStruct(shape=(N, H, K, V), dtype=k.dtype))
    out_specs.append(pl.BlockSpec(memory_space=pltpu.HBM))
else:
    out_shape.append(None)
    out_specs.append(None)

in_specs = [
    pl.BlockSpec(memory_space=pltpu.HBM),
    pl.BlockSpec(memory_space=pltpu.HBM),
]
k_scratch = pltpu.VMEM((2, BT, BK), jnp.float32)
v_scratch = pltpu.VMEM((2, BT, BV), jnp.float32)
o_scratch = pltpu.VMEM((BK, BV), jnp.float32)
h_out_scratch = pltpu.VMEM((2, BK, BV), jnp.float32)
ht_out_scratch = pltpu.VMEM((2, BK, BV), jnp.float32)
scratch_shapes = [k_scratch, v_scratch]
if h0 is not None:
    in_specs.append(pl.BlockSpec(memory_space=pltpu.HBM))
    h_scratch = pltpu.VMEM((2, BK, BV), jnp.float32)
    scratch_shapes.append(h_scratch)
else:
    scratch_shapes.append(None)
    in_specs.append(None)
if gk is not None:
    in_specs.append(pl.BlockSpec(memory_space=pltpu.HBM))
    gk_scratch = pltpu.VMEM((2, BT, BK), jnp.float32)
    scratch_shapes.append(gk_scratch)
```
Keep DMA staging buffers in the transferred dtype.
These buffers are all hard-coded to jnp.float32, but the DMAs touch k/v/gk/h0 inputs and h_ref/ht_ref outputs whose dtypes are caller-dependent. Please verify make_async_copy supports cross-dtype transfers; if it does not, this path is only safe for full-f32 tensors. The safer pattern is: DMA in the original ref dtype, then cast to f32 after the wait for exp / dot, and cast back only when staging h_ref / ht_ref.
♻️ Sketch

```diff
- k_scratch = pltpu.VMEM((2, BT, BK), jnp.float32)
- v_scratch = pltpu.VMEM((2, BT, BV), jnp.float32)
+ k_scratch = pltpu.VMEM((2, BT, BK), k.dtype)
+ v_scratch = pltpu.VMEM((2, BT, BV), v.dtype)
  o_scratch = pltpu.VMEM((BK, BV), jnp.float32)
- h_out_scratch = pltpu.VMEM((2, BK, BV), jnp.float32)
- ht_out_scratch = pltpu.VMEM((2, BK, BV), jnp.float32)
+ h_out_scratch = pltpu.VMEM((2, BK, BV), jnp.float32 if states_in_fp32 else k.dtype)
+ ht_out_scratch = pltpu.VMEM((2, BK, BV), k.dtype)
  ...
- h_scratch = pltpu.VMEM((2, BK, BV), jnp.float32)
+ h_scratch = pltpu.VMEM((2, BK, BV), h0.dtype)
  ...
- gk_scratch = pltpu.VMEM((2, BT, BK), jnp.float32)
+ gk_scratch = pltpu.VMEM((2, BT, BK), gk.dtype)
```

Then cast the loaded tiles to `jnp.float32` inside the kernel before the exp / dot math.
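The stage-in-native-dtype, compute-in-f32 pattern recommended here can be sketched with NumPy arrays standing in for the VMEM tiles (shapes, names, and the use of f16 in place of bf16 are illustrative):

```python
import numpy as np

def accumulate_tile(b_h, k_tile, v_tile):
    # Tiles arrive in their storage dtype (f16 here, standing in for bf16);
    # promote to f32 only for the numerically sensitive dot.
    k32 = k_tile.astype(np.float32)
    v32 = v_tile.astype(np.float32)
    return b_h + k32.T @ v32  # accumulate the running state in f32

b_h = np.zeros((4, 4), np.float32)
k_tile = np.ones((8, 4), np.float16)
v_tile = np.ones((8, 4), np.float16)
out = accumulate_tile(b_h, k_tile, v_tile)
assert out.dtype == np.float32           # state stays f32 across chunks
stored = out.astype(np.float16)          # cast back only when staging to HBM
```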
♻️ Duplicate comments (4)
src/ops/common/chunk_h.py (4)
575-590: ⚠️ Potential issue | 🟠 Major — Keep DMA staging buffers in the transferred dtype.

`k_scratch`, `v_scratch`, `gk_scratch`, `h_scratch`, `h_out_scratch`, and `ht_out_scratch` are all hard-coded to `jnp.float32`, but the corresponding HBM refs are caller-typed. Unless this path is intentionally f32-only, non-f32 inputs/outputs now rely on cross-dtype DMA. The safer pattern is to stage in the source/destination dtype, cast to `float32` only around the math, and cast back only when storing results.

For the current JAX TPU Pallas API, does `pltpu.make_async_copy` require source and destination refs to have the same dtype / element size, or are dtype conversions supported during HBM↔VMEM DMA? If conversions are not supported, what is the recommended pattern for bf16/f16 tiles that need float32 math in VMEM?

Also applies to: 596-599
262: ⚠️ Potential issue | 🔴 Critical — Validate the megacore split before computing `local_B`.

`local_B = B // 2` assumes the hard-coded `grid=(2,)` can partition the batch evenly. Odd `B` silently drops the tail item, and `B == 1` makes this path do no work.

🛡️ Minimal fix

```diff
  NT = pl.cdiv(T, BT)
  NTS = BS // BT
+ assert B % 2 == 0 and B >= 2, "B must be even and >= 2 for megacore parallelism."
  local_B = B // 2
```
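With a hypothetical odd batch, the silent tail drop is visible directly:

```python
B = 5                  # odd batch size
local_B = B // 2       # 2 items per core
covered = 2 * local_B  # grid=(2,) processes 4 of the 5 items
print(B - covered)     # → 1 batch item silently skipped
```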
520-530: ⚠️ Potential issue | 🔴 Critical — Reject partial trailing splits or size `NS` with ceil division.

This kernel emits a boundary state at `t_i = 0, NTS, 2 * NTS, ...`, so each batch item needs `ceil(T / BS)` slots. `NS = T // BS` only allocates `floor(T / BS)`, which makes the last store walk past `h_ref` when `T % BS != 0`.

🛡️ Minimal safe fix

```diff
  assert BS % BT == 0, (
      f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
  )
+ assert T % BS == 0, (
+     f"The sequence length {T} must be divisible by `split_size` {BS} in the same-seq kernel."
+ )
  # N: the actual number of sequences in the batch with either equal or variable lengths
  N, NS = (
      B,
      T // BS,
```
466-480: ⚠️ Potential issue | 🔴 Critical — Guard `output_ht()` when `output_final_state=False`.

`chunk_fwd_h_kernel_with_same_seq()` passes `None` for the second out spec in that mode, but this callback still dereferences `ht_ref.at[...]` unconditionally. The default path will fail on the last tile.

🛡️ Minimal fix

```diff
  @pl.when(t_i + 1 == NT)
  def output_ht():
      nonlocal ht_buf, next_ht_buf
+     if ht_ref is None:
+         return
      ht_out_scratch_ref[ht_buf] = o_scratch_ref[...]
```

Also applies to: 564-569
optim
async copy