Skip to content

[Pallas] Fix scalar .begin index not collapsing tensor dimensions#1972

Draft
norx1991 wants to merge 1 commit intomainfrom
yifeixu/pallas-scalar-index-fix
Draft

[Pallas] Fix scalar .begin index not collapsing tensor dimensions#1972
norx1991 wants to merge 1 commit intomainfrom
yifeixu/pallas-scalar-index-fix

Conversation

@norx1991
Copy link
Copy Markdown
Contributor

@norx1991 norx1991 commented Apr 7, 2026

Summary

  • Add TileBeginOrigin to get_block_id's accepted types so .begin SymInts get a block_id
  • Emit literal 0 for .begin indices in _pallas_index_str to collapse the dim, mirroring Triton's scalar SymInt handling
  • Fix adjust_block_size_constraints fallback that incorrectly assumed spec ordering matches tensor dim ordering
  • Skip out_pos increment for collapsed dims so None positions stay correct
  • Remove xfail from test added in [Pallas] Add xfail tests for scalar .begin index not collapsing dims #1971

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 7, 2026
@norx1991 norx1991 force-pushed the yifeixu/pallas-scalar-index-test-only branch from e419f95 to 643a8d4 Compare April 7, 2026 18:06
@norx1991 norx1991 force-pushed the yifeixu/pallas-scalar-index-fix branch from a52bee5 to 2180150 Compare April 7, 2026 18:10
@norx1991 norx1991 force-pushed the yifeixu/pallas-scalar-index-test-only branch from 643a8d4 to 8a58474 Compare April 7, 2026 18:10
@norx1991 norx1991 force-pushed the yifeixu/pallas-scalar-index-fix branch from 2180150 to 7eff4ec Compare April 7, 2026 18:11
@norx1991 norx1991 force-pushed the yifeixu/pallas-scalar-index-test-only branch from 8a58474 to c215b44 Compare April 7, 2026 18:18
@norx1991 norx1991 force-pushed the yifeixu/pallas-scalar-index-fix branch from 7eff4ec to 47aff58 Compare April 7, 2026 18:18
@norx1991 norx1991 force-pushed the yifeixu/pallas-scalar-index-test-only branch from c215b44 to 8d9cb31 Compare April 7, 2026 19:10
@norx1991 norx1991 force-pushed the yifeixu/pallas-scalar-index-fix branch from 47aff58 to 042d204 Compare April 7, 2026 19:14
@norx1991 norx1991 force-pushed the yifeixu/pallas-scalar-index-test-only branch from 8d9cb31 to a2ebcf1 Compare April 7, 2026 20:25
@norx1991 norx1991 force-pushed the yifeixu/pallas-scalar-index-fix branch from 042d204 to 10b4184 Compare April 7, 2026 20:28
@norx1991 norx1991 changed the base branch from yifeixu/pallas-scalar-index-test-only to main April 7, 2026 20:44
@norx1991 norx1991 force-pushed the yifeixu/pallas-scalar-index-fix branch from 10b4184 to 32711fc Compare April 7, 2026 20:44
When a kernel uses tile.begin as a scalar index, the Pallas codegen
emitted ':' keeping the dimension alive. Operations like .T then
failed with a rank mismatch (e.g. 2D permutation on a 3D tensor).

Fixes:

1. compile_environment.py: Add TileBeginOrigin to get_block_id's
   accepted types so .begin SymInts get a block_id.

2. memory_ops.py: Add _is_scalar_tile_offset helper and modify
   _pallas_index_str to emit literal 0 for .begin indices (collapsing
   the dim), mirroring Triton's scalar SymInt handling. Record in
   dim_map so BlockSpec tiles the dim correctly. Skip out_pos
   increment for collapsed dims so None positions stay correct.

3. backend.py: Skip alignment constraints when no tensor dim mapping
   is found, fixing a fallback that incorrectly assumed spec ordering
   matches tensor dim ordering.
@norx1991 norx1991 force-pushed the yifeixu/pallas-scalar-index-fix branch from 32711fc to 441d03e Compare April 8, 2026 22:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant