From c7f3865788076b1efc48bea717226d5dbf6022fc Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 19 Aug 2025 15:44:51 +0800 Subject: [PATCH] Update tree_util.py --- penzai/core/tree_util.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/penzai/core/tree_util.py b/penzai/core/tree_util.py index 0047eb5..d115c3d 100644 --- a/penzai/core/tree_util.py +++ b/penzai/core/tree_util.py @@ -49,8 +49,7 @@ def tree_flatten_exactly_one_level( paths_and_subtrees, treedef = jax.tree_util.tree_flatten_with_path( tree, is_leaf=lambda subtree: subtree is not tree ) - leaf_treedef = jax.tree_util.tree_structure(1) - if treedef == leaf_treedef: + if jax.tree_util.treedef_is_leaf(treedef): return None keys_and_subtrees = [