@@ -124,7 +124,8 @@ def compute_layout(self, graph: Any) -> dict[Any, tuple[float, float]]:
124124
125125 root_node = self ._select_root (graph )
126126 layered_nodes = self ._group_by_level (graph , root_node )
127- return self ._build_positions (layered_nodes )
127+ ordered_layers = self ._order_layers (graph , layered_nodes )
128+ return self ._build_positions (ordered_layers )
128129
129130 def _select_root (self , graph : Any ) -> Any :
130131 """Select a root node for layout.
@@ -171,6 +172,56 @@ def _group_by_level(self, graph: Any, root_node: Any) -> dict[int, list[Any]]:
171172 layered_nodes [orphan_layer ] = orphans
172173 return layered_nodes
173174
175+ def _order_layers (
176+ self , graph : Any , layered_nodes : dict [int , list [Any ]]
177+ ) -> dict [int , list [Any ]]:
178+ """Order nodes within each layer to keep parents above their children.
179+
180+ Reorders the nodes in each layer to improve visual readability when rendering a
181+ layered graph layout. The first layer is ordered deterministically (by string
182+ representation). For subsequent layers, nodes are ordered primarily by the
183+ lowest index of any already-ordered parent (predecessor) from the previous
184+ layers, with a deterministic string-based tie-breaker.
185+
186+ Args:
187+ graph:
188+ Graph-like object providing predecessor relationships via
189+ ``graph.predecessors(node)``.
190+ layered_nodes:
191+ Mapping from layer index to the list of nodes assigned to that layer.
192+
193+ Returns:
194+ A new mapping with the same layer keys as ``layered_nodes``, where each
195+ layer's node list is ordered to align children beneath their parents.
196+ """
197+ ordered_layers : dict [int , list [Any ]] = {}
198+ previous_order : dict [Any , int ] = {}
199+
200+ for level in sorted (layered_nodes ):
201+ nodes = layered_nodes [level ]
202+ if level == 0 :
203+ ordered_nodes = sorted (nodes , key = str )
204+ else :
205+
206+ def sort_key (node : Any ) -> tuple [int , str ]:
207+ parents = [
208+ parent
209+ for parent in graph .predecessors (node )
210+ if parent in previous_order
211+ ]
212+ if parents :
213+ parent_index = min (previous_order [parent ] for parent in parents )
214+ else :
215+ parent_index = len (previous_order )
216+ return (parent_index , str (node ))
217+
218+ ordered_nodes = sorted (nodes , key = sort_key )
219+
220+ ordered_layers [level ] = ordered_nodes
221+ previous_order = {node : idx for idx , node in enumerate (ordered_nodes )}
222+
223+ return ordered_layers
224+
174225 def _build_positions (
175226 self , layered_nodes : dict [int , list [Any ]]
176227 ) -> dict [Any , tuple [float , float ]]:
@@ -187,11 +238,11 @@ def _build_positions(
187238 """
188239 x_spacing , y_spacing = 2.5 , 1.2
189240 pos : dict [Any , tuple [float , float ]] = {}
190- for level , nodes in layered_nodes .items ():
241+ for level in sorted (layered_nodes ):
242+ nodes = layered_nodes [level ]
191243 x_pos = level * x_spacing
192- sorted_nodes = sorted (nodes , key = str )
193- y_start = (len (sorted_nodes ) - 1 ) * y_spacing / 2
194- for i , node in enumerate (sorted_nodes ):
244+ y_start = (len (nodes ) - 1 ) * y_spacing / 2
245+ for i , node in enumerate (nodes ):
195246 pos [node ] = (x_pos , y_start - (i * y_spacing ))
196247 return pos
197248
0 commit comments