From 5a702ffdde3b1d99f9ae9c964d747a656ba6e92d Mon Sep 17 00:00:00 2001 From: Stephen Malina <1093790+an1lam@users.noreply.github.com> Date: Thu, 27 Mar 2025 14:28:20 -0400 Subject: [PATCH] Explicitly set reentrant to `False` for torch checkpointing (#1) * Explicitly set reentrant to False for torch checkpointing As discussed [here](https://pytorch.org/docs/2.6/checkpoint.html#torch.utils.checkpoint.checkpoint), torch 2.4 and newer require explicitly passing `use_reentrant` to the checkpointing function. Prior to torch 2.4 (e.g. 2.2), use_reentrant defaulted to True; however, we have found that reentrant checkpointing does not work with DDP and torch.compile [1]. As a result, this change forces use_reentrant=False to enable us to use torch.compile with our structure models. * Update version * Update checkpointing.py * Update setup.py * Update setup.py --- openfold/utils/checkpointing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/openfold/utils/checkpointing.py b/openfold/utils/checkpointing.py index b2bb752cd..3351e6607 100644 --- a/openfold/utils/checkpointing.py +++ b/openfold/utils/checkpointing.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial import importlib from typing import Any, Tuple, List, Callable, Optional @@ -34,7 +35,7 @@ def get_checkpoint_fn(): if(deepspeed_is_configured): checkpoint = deepspeed.checkpointing.checkpoint else: - checkpoint = torch.utils.checkpoint.checkpoint + checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) return checkpoint