
[TPU][Pallas] Fix example/cross_entropy.py on Pallas TPU #2002

Open
yarongmu-google wants to merge 9 commits into pytorch:main from yarongmu-google:fix-pallas-dtype-mapping

Conversation

@yarongmu-google yarongmu-google commented Apr 10, 2026

The kernel currently has two common issues that need support:

  1. Long (int64) types are not supported in Pallas/Mosaic (XLA does support them, but Helion does not go through XLA).
  2. Directly indexing into vectors on HBM.

#2 is the bigger fix here.
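Issue #1 can be worked around outside the kernel. A minimal sketch, assuming a hypothetical `prepare_targets` helper (not part of this PR), that downcasts long indices before they reach Pallas:

```python
import torch

def prepare_targets(targets: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: Pallas/Mosaic on TPU has no int64 ("long")
    # support, so downcast class indices to int32 before the kernel.
    # Safe as long as the vocabulary size fits in int32 (< 2**31).
    if targets.dtype == torch.int64:
        targets = targets.to(torch.int32)
    return targets
```

The values are unchanged; only the storage dtype narrows to something Mosaic can lower.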

After this PR:

Benchmark Results

| Implementation | Time (ms) | Speedup     |
|----------------|-----------|-------------|
| helion         | 0.3826    | 1.10x       |
| torch          | 0.4208    | 1.00x (ref) |

norx1991 and others added 7 commits April 2, 2026 19:06
…errors and fix zero division in block size calculation
…py.py to avoid unaligned HBM gather

This makes the cross_entropy kernel hardware agnostic. The target logits are computed via a boolean mask over the streaming dense block, so the kernel stays entirely within TensorCore/VMEM boundaries on TPU and remains perfectly coalesced on GPU. This eliminates the unaligned 1D HBM gather, which Pallas TC kernels do not natively support without SC DMA staging.
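The masking trick above can be sketched in plain PyTorch. This is an illustrative stand-in, not the PR's actual kernel code; the function name and shapes are assumptions. Instead of the gather `logits[i, targets[i]]`, it builds a boolean one-hot mask over the dense `[N, V]` block and reduces along the vocab axis:

```python
import torch

def target_logits_via_mask(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # logits: [N, V] dense block; targets: [N] class indices.
    # Compare a [V] column-index vector against targets to get a
    # boolean one-hot mask, then reduce over the vocab axis. The
    # access pattern stays dense, so it maps cleanly onto TPU VMEM
    # tiles and coalesced GPU loads, with no 1D HBM gather.
    n, v = logits.shape
    cols = torch.arange(v, dtype=targets.dtype, device=logits.device)  # [V]
    mask = cols.unsqueeze(0) == targets.unsqueeze(1)                   # [N, V]
    return (logits * mask).sum(dim=-1)                                 # [N]
```

The result matches `logits.gather(1, targets.unsqueeze(1)).squeeze(1)`, trading an indexed load for a dense compare-and-reduce.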
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 10, 2026
