Skip to content

Skip _check_forward_args validation during make_fx tracing#454

Open
IvanKobzarev wants to merge 1 commit into
mainfrom
assert_fix
Open

Skip _check_forward_args validation during make_fx tracing#454
IvanKobzarev wants to merge 1 commit into
mainfrom
assert_fix

Conversation

@IvanKobzarev
Copy link
Copy Markdown
Contributor

During make_fx tracing, the forward args are FakeTensors which have different device/repr than the meta tensors captured during AutoParallel tracing. This causes _check_forward_args to raise spurious validation errors. Detect this by checking for meta tensors or FakeTensors in the args and skip validation in that case.

Authored by Claude.

During make_fx tracing, the forward args are FakeTensors which have
different device/repr than the meta tensors captured during AutoParallel
tracing. This causes _check_forward_args to raise spurious validation
errors. Detect this by checking for meta tensors or FakeTensors in
the args and skip validation in that case.

Authored by Claude.
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 11, 2026
expected_inputs: Expected shapes from _compute_expected_inputs.
dynamic_dims: Set of (arg_index, dim) pairs for dynamic dimensions.
"""
# Skip validation during make_fx tracing -- FakeTensors have different
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Normally this check only runs after AutoParallel has finished, as is meant as a safety for users calling the parallelized model in their code.

When we do torch.compile(parallel_mod), this check should normally be properly captured.

So I'm not sure if this skipping is actually doing what we want globally, but it would be good to discuss with others from the GraphTrainer if they should just completely skip this in their implementation of AutoParallelGraph

@fmassa
Copy link
Copy Markdown
Contributor

fmassa commented May 11, 2026

I think I'd instead just change the AutoParallelGraph to not call into _check_forward_args, i.e., removing https://github.com/pytorch/torchtitan/blob/d107aa4862a9a266243b037037130e8d8fa68b97/torchtitan/experiments/graph_trainer/autoparallel_api.py#L129

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants