Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions src/cuga/backend/cuga_graph/nodes/cuga_lite/cuga_lite_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,13 +329,15 @@ async def create_find_tools_tool(
all_tools: Sequence[StructuredTool],
all_apps: List[Any],
app_to_tools_map: Optional[Dict[str, List[StructuredTool]]] = None,
filter_criteria: Optional[Dict[str, Any]] = None,
) -> StructuredTool:
"""Create a find_tools StructuredTool for tool discovery.

Args:
all_tools: All available tools to search through
all_apps: All available app definitions
app_to_tools_map: Optional mapping of app_name -> list of tools. If provided, used for filtering by app_name.
filter_criteria: Optional dictionary of metadata filters to apply to tools (e.g., {"domain": "hockey"})

Returns:
StructuredTool configured for finding relevant tools
Expand All @@ -359,6 +361,29 @@ async def find_tools_func(query: str, app_name: str):
)
filtered_tools = []

# Apply filter criteria if specified
if filter_criteria and filtered_tools:
criteria_filtered_tools = []
for tool in filtered_tools:
if hasattr(tool, 'metadata') and tool.metadata:
# Check if all filter criteria match
matches_all = all(
tool.metadata.get(key) == value
for key, value in filter_criteria.items()
)
if matches_all:
criteria_filtered_tools.append(tool)

if criteria_filtered_tools:
filtered_tools = criteria_filtered_tools
criteria_str = ", ".join(f"{k}='{v}'" for k, v in filter_criteria.items())
logger.info(f"Filtered {len(filtered_tools)} tools matching criteria: {criteria_str}")
else:
criteria_str = ", ".join(f"{k}='{v}'" for k, v in filter_criteria.items())
logger.warning(
f"No tools found matching criteria ({criteria_str}) in app '{app_name}'. Using all {len(filtered_tools)} tools from app."
)

filtered_apps = [app for app in all_apps if hasattr(app, 'name') and app.name == app_name]

if not filtered_apps:
Expand Down Expand Up @@ -453,6 +478,7 @@ def create_cuga_lite_graph(
thread_id: Optional[str] = None,
callbacks: Optional[List[BaseCallbackHandler]] = None,
special_instructions: Optional[str] = None,
filter_criteria: Optional[Dict[str, Any]] = None,
) -> StateGraph:
"""
Create a unified CugaLite subgraph combining CodeAct and CugaAgent functionality.
Expand Down Expand Up @@ -487,6 +513,7 @@ def create_prepare_node(
base_instructions,
tools_context_dict,
base_special_instructions,
base_filter_criteria,
):
"""Factory to create prepare node with closure over tool provider and config."""

Expand Down Expand Up @@ -605,7 +632,7 @@ async def prepare_tools_and_apps(
tools_for_prompt = tools_for_execution
if enable_find_tools:
find_tool = await create_find_tools_tool(
all_tools=tools_for_execution, all_apps=apps_for_prompt, app_to_tools_map=app_to_tools_map
all_tools=tools_for_execution, all_apps=apps_for_prompt, app_to_tools_map=app_to_tools_map, filter_criteria=base_filter_criteria
)
tools_for_prompt = [find_tool]
# Add find_tools to tools context for sandbox execution
Expand Down Expand Up @@ -1106,7 +1133,7 @@ async def sandbox(state: CugaLiteState, config: Optional[RunnableConfig] = None)

# Create node instances using factories
prepare_node = create_prepare_node(
tool_provider, prompt_template, instructions, tools_context, special_instructions
tool_provider, prompt_template, instructions, tools_context, special_instructions, filter_criteria
)
call_model_node = create_call_model_node(model, callbacks)
sandbox_node = create_sandbox_node(tools_context, thread_id, apps_list)
Expand Down
4 changes: 4 additions & 0 deletions src/cuga/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,7 @@ def __init__(
auto_load_policies: Optional[bool] = None,
reset_policy_storage: bool = False,
filesystem_sync: Optional[bool] = None,
filter_criteria: Optional[Dict[str, Any]] = None,
):
"""
Initialize the CUGA Agent.
Expand All @@ -1220,6 +1221,7 @@ def __init__(
auto_load_policies: If True, automatically loads policies from cuga_folder
reset_policy_storage: If True, clears all existing policies from storage on init
filesystem_sync: If True, saves policies to .cuga when added/updated (default: True)
filter_criteria: Optional dictionary of metadata criteria to filter tools (e.g., {"domain": "sports"})

Example with tool approval policy:
```python
Expand Down Expand Up @@ -1252,6 +1254,7 @@ def __init__(
self._compiled_graph = None
self._policy_system = policy_system
self._special_instructions = special_instructions
self._filter_criteria = filter_criteria

# Use settings defaults if not provided
self.cuga_folder = cuga_folder if cuga_folder is not None else settings.policy.cuga_folder
Expand Down Expand Up @@ -1357,6 +1360,7 @@ def _create_hitl_wrapper_graph(self, thread_id: Optional[str] = None):
thread_id=thread_id,
callbacks=self._callbacks,
special_instructions=self._special_instructions,
filter_criteria=self._filter_criteria,
)
# Compile subgraph without checkpointer so it streams internal updates
compiled_subgraph = cuga_lite_subgraph.compile()
Expand Down