Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 28 additions & 19 deletions evalbench/generators/models/gcp_data_engineering_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,17 @@ def _get_and_refresh_token(self) -> str:
return token_val


def _is_iterable(obj: Any) -> bool:
"""Safely checks if an object is iterable, excluding strings and bytes."""
if isinstance(obj, (str, bytes)):
return False
try:
iter(obj)
return True
except TypeError:
return False


def _extract_message_text(msg: Any) -> str:
"""Extracts agent message text from a single Message object."""
if getattr(msg, "role", None) != pb.ROLE_AGENT:
Expand All @@ -141,51 +152,49 @@ def _extract_message_text(msg: Any) -> str:


def _find_agent_text_recursive(obj: Any) -> str:
"""Recursively searches obj to find the first valid agent text."""
text = _extract_message_text(obj)
if text:
return text
"""Recursively searches obj to find and accumulate agent texts."""
texts = []

self_text = _extract_message_text(obj)
if self_text:
texts.append(self_text)

# 1. Handle dict-like mappings by traversing their values
if isinstance(obj, collections.abc.Mapping):
for val in obj.values():
text = _find_agent_text_recursive(val)
if text:
return text
texts.append(text)
return "\n\n".join(texts)

# 2. Handle iterables (exclude string/bytes)
is_iterable = False
if not isinstance(obj, (str, bytes)):
try:
iter(obj)
is_iterable = True
except TypeError as e:
logger.info("Object is not iterable: %s", e)

if is_iterable:
if _is_iterable(obj):
for item in obj:
text = _find_agent_text_recursive(item)
if text:
return text
texts.append(text)
return "\n\n".join(texts)

# 3. Handle standard Protobuf Messages via ListFields
elif hasattr(obj, "ListFields"):
try:
for field_desc, field_value in obj.ListFields():
text = _find_agent_text_recursive(field_value)
if text:
return text
texts.append(text)
except Exception:
pass
return "\n\n".join(texts)

# 4. Fallback for other standard objects
elif hasattr(obj, "__dict__"):
for val in obj.__dict__.values():
text = _find_agent_text_recursive(val)
if text:
return text
texts.append(text)
return "\n\n".join(texts)

return ""
return "\n\n".join(texts)


class DataEngineeringAgentGenerator(QueryGenerator):
Expand Down Expand Up @@ -336,7 +345,7 @@ async def _run_client(
message_req.metadata[CONVERSATION_TOKEN_URI] = token

context = ClientCallContext(
timeout=180.0,
timeout=300.0,
service_parameters={
"A2A-Extensions": ALL_EXTENSIONS
}
Expand Down
38 changes: 38 additions & 0 deletions evalbench/test/gcp_data_engineering_agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,3 +467,41 @@ def test_extract_reply_text_heuristic_recursive_fallback():

container = [{"some_key": "some_value"}, [nested_msg]]
assert _find_agent_text_recursive(container) == "Nested message text"


def test_find_agent_text_recursive_accumulation():
# 1. Test multiple agent messages are accumulated
# and joined by double newlines
nested_msg1 = pb.Message(role=pb.ROLE_AGENT)
nested_msg1.parts.append(pb.Part(text="Part 1 text"))
nested_msg2 = pb.Message(role=pb.ROLE_AGENT)
nested_msg2.parts.append(pb.Part(text="Part 2 text"))

container = [
nested_msg1,
{"some_other_key": nested_msg2},
]
expected = "Part 1 text\n\nPart 2 text"
assert _find_agent_text_recursive(container) == expected

# 2. Test resilience against non-iterable primitives
# (should skip them and not crash)
class NonIterableObject:
pass

nested_msg3 = pb.Message(role=pb.ROLE_AGENT)
nested_msg3.parts.append(pb.Part(text="Valid text"))

mixed_container = {
"number": 42,
"flag": True,
"none_value": None,
"custom_obj": NonIterableObject(),
"nested": nested_msg3,
}
assert _find_agent_text_recursive(mixed_container) == "Valid text"

# 3. Test empty/edge cases return empty string
assert _find_agent_text_recursive(None) == ""
assert _find_agent_text_recursive([]) == ""
assert _find_agent_text_recursive({}) == ""
Loading