Skip to content

[Pallas] Fix fori_loop multi-dim inner loop index unflattening#1995

Closed
thcmbs wants to merge 1 commit intopytorch:mainfrom
thcmbs:tcombes/fix-fori-loop-multidim-index
Closed

[Pallas] Fix fori_loop multi-dim inner loop index unflattening#1995
thcmbs wants to merge 1 commit intopytorch:mainfrom
thcmbs:tcombes/fix-fori-loop-multidim-index

Conversation

@thcmbs
Copy link
Copy Markdown
Collaborator

@thcmbs thcmbs commented Apr 9, 2026

Fix out-of-bounds DMA access when fori_loop iterates over multiple inner block dimensions

The flat _j loop variable was used directly for all dimensions' offsets; now we unflatten it into per-dimension indices (_j_0, _j_1, ...) via divmod decomposition

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 9, 2026
@thcmbs thcmbs marked this pull request as ready for review April 9, 2026 15:02
@norx1991
Copy link
Copy Markdown
Contributor

norx1991 commented Apr 9, 2026

FYI, there are some comments under #1917. We see slightly better performance with double fori_loop, and it also seems a more natural way to implement this.

@thcmbs
Copy link
Copy Markdown
Collaborator Author

thcmbs commented Apr 13, 2026

Thanks for sharing! Given the perf diff, there is no debate :)

@thcmbs
Copy link
Copy Markdown
Collaborator Author

thcmbs commented Apr 17, 2026

Closing given that #1917 was merged.

@thcmbs thcmbs closed this Apr 17, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants