Skip to content

Commit ec4446f

Browse files
DeanChensjcopybara-github
authored andcommitted
chore: Introduce _BranchPath utility class for dynamic execution branches
Implement _BranchPath, a hierarchical path representation for dot-separated dynamic execution branch strings (e.g. 'parent@1.child@2.node'). This class provides a unified, segment-based API for: - Parsing and serialization of branch strings. - Extracting all run IDs across the entire branch chain. - Robust hierarchy checks (is_descendant_of) that are immune to partial prefix matches. Also add comprehensive unit tests covering all features and boundary cases. Co-authored-by: Shangjie Chen <deanchen@google.com> PiperOrigin-RevId: 938209288
1 parent 4c4f77a commit ec4446f

5 files changed

Lines changed: 272 additions & 24 deletions

File tree

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Hierarchical path representation for dynamic execution branches."""
16+
17+
from __future__ import annotations
18+
19+
20+
class _BranchPath:
21+
"""Represents a hierarchical path for execution branches.
22+
23+
A path consists of dot-separated segments (e.g., 'segment1.segment2'),
24+
where each segment represents a node run and is typically formatted
25+
as 'node_name@run_id' or 'node_name'.
26+
27+
Example:
28+
'parent_agent@1.collect_user_info_tool@2.sub_workflow'
29+
"""
30+
31+
def __init__(self, segments: list[str]):
32+
"""Initializes a _BranchPath with a list of segments."""
33+
self._segments = list(segments)
34+
35+
@classmethod
36+
def from_string(cls, path_str: str | None) -> _BranchPath:
37+
"""Parses a _BranchPath from a dot-separated string representation."""
38+
if not path_str:
39+
return cls([])
40+
return cls(path_str.split('.'))
41+
42+
def __str__(self) -> str:
43+
"""Returns the dot-separated string representation of the path."""
44+
return '.'.join(self._segments)
45+
46+
def __eq__(self, other: object) -> bool:
47+
"""Returns True if segments are equal."""
48+
if not isinstance(other, _BranchPath):
49+
return NotImplemented
50+
return self._segments == other._segments
51+
52+
@property
53+
def segments(self) -> list[str]:
54+
"""Returns a copy of the path segments."""
55+
return list(self._segments)
56+
57+
@property
58+
def run_ids(self) -> set[str]:
59+
"""Extracts all run IDs (the part after '@') from all segments in the path.
60+
61+
Example:
62+
- Path: 'parent@1.child@2.node'
63+
- Returns: {'1', '2'}
64+
"""
65+
ids = set()
66+
for segment in self._segments:
67+
parts = segment.rsplit('@', 1)
68+
if len(parts) > 1 and parts[1]:
69+
ids.add(parts[1])
70+
return ids
71+
72+
@property
73+
def parent(self) -> _BranchPath | None:
74+
"""Returns the parent _BranchPath, or None if this is a root path."""
75+
if len(self._segments) <= 1:
76+
return None
77+
return _BranchPath(self._segments[:-1])
78+
79+
def is_descendant_of(self, ancestor: _BranchPath) -> bool:
80+
"""Checks if this path is a descendant of the ancestor path.
81+
82+
A path is a descendant if it starts with all segments of the ancestor path
83+
and has additional segments.
84+
"""
85+
if len(self._segments) <= len(ancestor._segments):
86+
return False
87+
return self._segments[: len(ancestor._segments)] == ancestor._segments
88+
89+
@staticmethod
90+
def common_prefix(paths: list[_BranchPath]) -> _BranchPath:
91+
"""Finds the common prefix of a list of _BranchPath objects."""
92+
if not paths:
93+
return _BranchPath([])
94+
95+
common_segments = []
96+
for segments in zip(*[p.segments for p in paths]):
97+
if len(set(segments)) == 1:
98+
common_segments.append(segments[0])
99+
else:
100+
break
101+
return _BranchPath(common_segments)

src/google/adk/flows/llm_flows/contents.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from typing_extensions import override
2424

2525
from ...agents.invocation_context import InvocationContext
26+
from ...events._branch_path import _BranchPath
2627
from ...events.event import Event
2728
from ...models.llm_request import LlmRequest
2829
from ._base_llm_processor import BaseLlmRequestProcessor
@@ -900,12 +901,10 @@ def _is_event_belongs_to_branch(
900901
"""
901902
if not invocation_branch or not event.branch:
902903
return True
903-
# We use dot to delimit branch nodes. To avoid simple prefix match
904-
# (e.g. agent_0 unexpectedly matching agent_00), require either perfect branch
905-
# match, or match prefix with an additional explicit '.'
906-
return invocation_branch == event.branch or invocation_branch.startswith(
907-
f'{event.branch}.'
908-
)
904+
905+
inv_path = _BranchPath.from_string(invocation_branch)
906+
evt_path = _BranchPath.from_string(event.branch)
907+
return inv_path == evt_path or inv_path.is_descendant_of(evt_path)
909908

910909

911910
def _is_function_call_event(event: Event, function_name: str) -> bool:

src/google/adk/workflow/_join_node.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from typing_extensions import override
2424

2525
from ..agents.context import Context
26+
from ..events._branch_path import _BranchPath
2627
from ..events.event import Event
2728
from ._base_node import BaseNode
2829

@@ -33,15 +34,8 @@ def _get_common_branch_prefix(branches: list[str]) -> str:
3334
"""Find the common prefix of dot-separated branch strings."""
3435
if not branches:
3536
return ''
36-
split_branches = [b.split('.') if b else [] for b in branches]
37-
38-
common = []
39-
for segments in zip(*split_branches):
40-
if len(set(segments)) == 1:
41-
common.append(segments[0])
42-
else:
43-
break
44-
return '.'.join(common)
37+
paths = [_BranchPath.from_string(b) for b in branches]
38+
return str(_BranchPath.common_prefix(paths))
4539

4640

4741
class JoinNode(BaseNode):

src/google/adk/workflow/_workflow.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
from pydantic import Field
3232

33+
from ..events._branch_path import _BranchPath
3334
from ._base_node import BaseNode
3435
from ._base_node import START
3536
from ._dynamic_node_scheduler import DynamicNodeScheduler
@@ -58,15 +59,8 @@ def get_common_branch_prefix(branches: list[str]) -> str:
5859
"""Find the common prefix of dot-separated branch strings."""
5960
if not branches:
6061
return ''
61-
split_branches = [b.split('.') if b else [] for b in branches]
62-
63-
common = []
64-
for segments in zip(*split_branches):
65-
if len(set(segments)) == 1:
66-
common.append(segments[0])
67-
else:
68-
break
69-
return '.'.join(common)
62+
paths = [_BranchPath.from_string(b) for b in branches]
63+
return str(_BranchPath.common_prefix(paths))
7064

7165

7266
# ---------------------------------------------------------------------------
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for _BranchPath.
16+
17+
Verifies that _BranchPath correctly parses, serializes, and manipulates
18+
hierarchical dynamic execution branch paths.
19+
"""
20+
21+
from __future__ import annotations
22+
23+
from google.adk.events._branch_path import _BranchPath
24+
import pytest
25+
26+
27+
def test_from_string_with_empty_string_returns_empty_path():
28+
"""Parsing an empty string returns a _BranchPath with no segments."""
29+
path = _BranchPath.from_string("")
30+
31+
assert path.segments == []
32+
assert str(path) == ""
33+
34+
35+
def test_from_string_with_single_segment_returns_path_with_one_segment():
36+
"""Parsing a single name returns a _BranchPath with one segment."""
37+
path = _BranchPath.from_string("agent_0")
38+
39+
assert path.segments == ["agent_0"]
40+
assert str(path) == "agent_0"
41+
42+
43+
def test_from_string_with_multiple_segments_returns_path_with_all_segments():
44+
"""Parsing a dot-separated string returns a _BranchPath with all segments."""
45+
path = _BranchPath.from_string("parent.child.node")
46+
47+
assert path.segments == ["parent", "child", "node"]
48+
assert str(path) == "parent.child.node"
49+
50+
51+
def test_equality_compares_path_segments():
52+
"""Two _BranchPath objects are equal if and only if their segments match."""
53+
path1 = _BranchPath.from_string("parent.child")
54+
path2 = _BranchPath.from_string("parent.child")
55+
path3 = _BranchPath.from_string("parent.other")
56+
57+
assert path1 == path2
58+
assert path1 != path3
59+
assert path1 != "parent.child" # Different type
60+
61+
62+
def test_run_ids_extracts_all_run_ids_from_path():
63+
"""run_ids extracts all run IDs (the part after '@') from all segments."""
64+
# Given paths with various run ID patterns
65+
path_with_ids = _BranchPath.from_string("parent@1.child@2.node")
66+
path_no_ids = _BranchPath.from_string("parent.child")
67+
path_mixed = _BranchPath.from_string("parent@1.child.node@3")
68+
69+
# Then the extracted run IDs match expectations
70+
assert path_with_ids.run_ids == {"1", "2"}
71+
assert path_no_ids.run_ids == set()
72+
assert path_mixed.run_ids == {"1", "3"}
73+
74+
75+
def test_parent_returns_parent_path_or_none_for_root():
76+
"""parent returns a new _BranchPath excluding the leaf segment, or None."""
77+
path = _BranchPath.from_string("parent.child.node")
78+
79+
assert path.parent == _BranchPath.from_string("parent.child")
80+
assert path.parent.parent == _BranchPath.from_string("parent")
81+
assert path.parent.parent.parent is None
82+
83+
84+
def test_is_descendant_of_verifies_path_hierarchy_safely():
85+
"""is_descendant_of returns True if the path is a strict sub-path of ancestor."""
86+
# Given an ancestor and various comparison paths
87+
ancestor = _BranchPath.from_string("parent.child")
88+
descendant = _BranchPath.from_string("parent.child.node.leaf")
89+
not_descendant = _BranchPath.from_string("parent.other")
90+
same = _BranchPath.from_string("parent.child")
91+
92+
# Then descendant checks match expectations
93+
assert descendant.is_descendant_of(ancestor)
94+
assert not ancestor.is_descendant_of(descendant)
95+
assert not not_descendant.is_descendant_of(ancestor)
96+
assert not same.is_descendant_of(ancestor)
97+
98+
99+
def test_is_descendant_of_is_immune_to_partial_name_prefix_match():
100+
"""is_descendant_of compares segments, avoiding partial string prefix bugs."""
101+
# Given an ancestor and a path that has a partial string prefix but different segment
102+
ancestor = _BranchPath.from_string("agent_0")
103+
descendant_with_prefix = _BranchPath.from_string("agent_00.child")
104+
105+
# Then it is not recognized as a descendant because segments don't match
106+
assert not descendant_with_prefix.is_descendant_of(ancestor)
107+
108+
109+
def test_common_prefix_finds_longest_shared_path():
110+
"""common_prefix returns the longest common prefix of a list of paths."""
111+
# Given a list of paths sharing a common prefix
112+
paths = [
113+
_BranchPath.from_string("parent.child.node1"),
114+
_BranchPath.from_string("parent.child.node2.leaf"),
115+
_BranchPath.from_string("parent.child.node3"),
116+
]
117+
118+
# When finding the common prefix
119+
result = _BranchPath.common_prefix(paths)
120+
121+
# Then the result matches the shared parent path
122+
assert result == _BranchPath.from_string("parent.child")
123+
124+
125+
def test_common_prefix_with_no_shared_path_returns_empty():
126+
"""common_prefix returns an empty path if there is no shared prefix."""
127+
paths = [
128+
_BranchPath.from_string("parent.child"),
129+
_BranchPath.from_string("other.child"),
130+
]
131+
132+
result = _BranchPath.common_prefix(paths)
133+
134+
assert result == _BranchPath.from_string("")
135+
136+
137+
def test_common_prefix_with_empty_list_returns_empty():
138+
"""common_prefix returns an empty path if the input list is empty."""
139+
result = _BranchPath.common_prefix([])
140+
141+
assert result == _BranchPath.from_string("")
142+
143+
144+
def test_constructor_copies_segments_list():
145+
"""_BranchPath copies the input segments list to ensure immutability."""
146+
segments = ["parent", "child"]
147+
path = _BranchPath(segments)
148+
149+
# Mutate the original list
150+
segments.append("grandchild")
151+
152+
# The path segments should remain unchanged
153+
assert path.segments == ["parent", "child"]
154+
155+
156+
def test_run_ids_filters_out_empty_run_ids():
157+
"""run_ids filters out segments with empty run IDs (e.g. ending with '@')."""
158+
path = _BranchPath.from_string("parent@.child@2.node@")
159+
160+
assert path.run_ids == {"2"}

0 commit comments

Comments
 (0)