Skip to content

Preserve DeviceMesh identity in expand_rule's op_schema deepcopy#488

Open
fmassa wants to merge 5 commits into
mainfrom
fmassa/preserve_devicemesh_identity
Open

Preserve DeviceMesh identity in expand_rule's op_schema deepcopy#488
fmassa wants to merge 5 commits into
mainfrom
fmassa/preserve_devicemesh_identity

Conversation

@fmassa

@fmassa fmassa commented Jun 10, 2026

Copy link
Copy Markdown
Contributor

expand_rule used copy.deepcopy(op_schema_) to snapshot the schema before mutating it. DeviceMesh has no deepcopy, so deepcopy went through getstate/setstate and produced a fresh DeviceMesh with the same logical content but an empty _flatten_mapping cache. The DTensorSpecs returned from expand_rule carried these duplicates, which propagated into the sharding solution.

apply_placement's pre-warming in _apply_placement_common only populates the user mesh's cache. When _optimize_same_nd_sharding_as_1d inside make_fx hit a duplicate mesh, _flatten() cache-missed and dispatched as_strided on the real rank_map — failing FakeTensorMode's non-fake-input check. Which solution the solver picked depended on the cost model, so the failure surfaced on g5/A10G CI but not on H100.

Fix: _deepcopy_preserving_mesh pre-seeds copy.deepcopy's memo with DeviceMesh identity mappings so duplicates aren't produced. Adds a regression test that asserts every spec mesh's root has a warm _flatten cache after apply_placement.

Authored with Claude.

expand_rule used copy.deepcopy(op_schema_) to snapshot the schema before
mutating it. DeviceMesh has no __deepcopy__, so deepcopy went through
__getstate__/__setstate__ and produced a fresh DeviceMesh with the same
logical content but an empty _flatten_mapping cache. The DTensorSpecs
returned from expand_rule carried these duplicates, which propagated
into the sharding solution.

apply_placement's pre-warming in _apply_placement_common only populates
the user mesh's cache. When _optimize_same_nd_sharding_as_1d inside
make_fx hit a duplicate mesh, _flatten() cache-missed and dispatched
as_strided on the real rank_map — failing FakeTensorMode's non-fake-input
check. Which solution the solver picked depended on the cost model, so
the failure surfaced on g5/A10G CI but not on H100.

Fix: _deepcopy_preserving_mesh pre-seeds copy.deepcopy's memo with
DeviceMesh identity mappings so duplicates aren't produced. Adds a
regression test that asserts every spec mesh's root has a warm
_flatten cache after apply_placement.

Authored with Claude.
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 10, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant