Skip _check_forward_args validation during make_fx tracing#454
Open
IvanKobzarev wants to merge 1 commit into
Open
Skip _check_forward_args validation during make_fx tracing#454IvanKobzarev wants to merge 1 commit into
IvanKobzarev wants to merge 1 commit into
Conversation
During make_fx tracing, the forward args are FakeTensors which have different device/repr than the meta tensors captured during AutoParallel tracing. This causes _check_forward_args to raise spurious validation errors. Detect this by checking for meta tensors or FakeTensors in the args and skip validation in that case. Authored by Claude.
fmassa
reviewed
May 11, 2026
| expected_inputs: Expected shapes from _compute_expected_inputs. | ||
| dynamic_dims: Set of (arg_index, dim) pairs for dynamic dimensions. | ||
| """ | ||
| # Skip validation during make_fx tracing -- FakeTensors have different |
Contributor
There was a problem hiding this comment.
Normally this check only runs after AutoParallel has finished, as is meant as a safety for users calling the parallelized model in their code.
When we do torch.compile(parallel_mod), this check should normally be properly captured.
So I'm not sure if this skipping is actually doing what we want globally, but it would be good to discuss with others from the GraphTrainer if they should just completely skip this in their implementation of AutoParallelGraph
Contributor
|
I think I'd instead just change the AutoParallelGraph to not call into |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
During make_fx tracing, the forward args are FakeTensors which have different device/repr than the meta tensors captured during AutoParallel tracing. This causes _check_forward_args to raise spurious validation errors. Detect this by checking for meta tensors or FakeTensors in the args and skip validation in that case.
Authored by Claude.