fix: validate sequence length divisibility for zigzag ring attention #57
Open
Bhavyashah20 wants to merge 1 commit into
ZigZag ring attention requires seq_len % (2 * world_size) == 0. When this constraint is violated, torch.Tensor.chunk() produces unequal chunks, causing a size mismatch error in the backward pass.

Add validation in extract_local() with a clear error message explaining the issue and how to fix it.

Add pytest suite for prepare_inputs.py.

Fixes jzhang38#47
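As a rough illustration of the check described above, a validation helper might look like the following. This is a sketch only: the function name `check_zigzag_seq_len`, its signature, and the error wording are assumptions made for illustration, not the code in this PR.

```python
def check_zigzag_seq_len(seq_len: int, world_size: int) -> None:
    """Raise ValueError unless seq_len % (2 * world_size) == 0.

    Hypothetical helper: the name and message are illustrative and are
    not taken from the repository.
    """
    if world_size <= 0:
        raise ValueError(f"world_size must be positive, got {world_size}")

    required = 2 * world_size
    remainder = seq_len % required
    if remainder != 0:
        # Smallest length >= seq_len that satisfies the constraint.
        padded = seq_len + (required - remainder)
        raise ValueError(
            f"seq_len={seq_len} is not divisible by 2 * world_size={required}; "
            f"pad or truncate the sequence (next valid length: {padded})."
        )
```

Calling it before splitting the sequence, for example `check_zigzag_seq_len(seq_len=10, world_size=4)`, fails immediately with a message naming the constraint instead of surfacing later as a shape error.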
Author
Bumping this in case it got buried: the zigzag attention backward pass fails with a cryptic size mismatch error when sequence length isn't divisible by The fix adds an early validation check in
Author
Hi, just following up on this. It's been about a month since I opened this PR. The validation catches a silent failure in zigzag ring attention that's otherwise very hard to debug. Happy to revise if anything looks off. Thanks!
Found this while reading through the codebase. The zigzag attention
backward pass fails with a cryptic size mismatch error when sequence
length isn't divisible by 2 * world_size. Took me a while to trace
it back to the unequal chunks from torch.chunk().
Added a validation check in extract_local() with a clear error message
pointing users to the fix. Also added a pytest suite since there were
no tests for this module.
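To make the unequal-chunk behaviour concrete, the sketch below models the chunking rule documented for torch.chunk() in plain Python: each chunk has size ceil(n / chunks), and the last one takes whatever is left. It does not call PyTorch, and the helper name `chunk_sizes` is an assumption for illustration.

```python
import math
from typing import List


def chunk_sizes(seq_len: int, chunks: int) -> List[int]:
    """Model the chunk lengths produced by splitting seq_len into `chunks` parts.

    Follows the documented torch.chunk() rule: every chunk has size
    ceil(seq_len / chunks) except possibly the last, and fewer than
    `chunks` pieces may come back. Hypothetical helper, pure Python.
    """
    if chunks <= 0:
        raise ValueError("chunks must be positive")
    step = math.ceil(seq_len / chunks)
    sizes = []
    remaining = seq_len
    while remaining > 0:
        size = min(step, remaining)
        sizes.append(size)
        remaining -= size
    return sizes


# world_size = 4, so zigzag splits the sequence into 2 * 4 = 8 chunks.
print(chunk_sizes(16, 8))  # divisible: eight equal chunks
print(chunk_sizes(13, 8))  # not divisible: short last chunk, only seven chunks
print(chunk_sizes(10, 8))  # not divisible: only five chunks
```

With a divisible length every rank gets two chunks of the same size. Otherwise the last chunk is shorter or some chunks are missing altogether, which is the condition the validation is meant to reject up front.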