From 2b68082a498b3e97bfc35fb90fc1e3f19390185d Mon Sep 17 00:00:00 2001 From: Hamid Adebayo Date: Thu, 26 Feb 2026 10:41:07 -0600 Subject: [PATCH] feat: add filter_criteria parameter for metadata-based tool filtering --- .../nodes/cuga_lite/cuga_lite_graph.py | 31 +++++++++++++++++-- src/cuga/sdk.py | 4 +++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/src/cuga/backend/cuga_graph/nodes/cuga_lite/cuga_lite_graph.py b/src/cuga/backend/cuga_graph/nodes/cuga_lite/cuga_lite_graph.py index d0b1d214..282569e7 100644 --- a/src/cuga/backend/cuga_graph/nodes/cuga_lite/cuga_lite_graph.py +++ b/src/cuga/backend/cuga_graph/nodes/cuga_lite/cuga_lite_graph.py @@ -329,6 +329,7 @@ async def create_find_tools_tool( all_tools: Sequence[StructuredTool], all_apps: List[Any], app_to_tools_map: Optional[Dict[str, List[StructuredTool]]] = None, + filter_criteria: Optional[Dict[str, Any]] = None, ) -> StructuredTool: """Create a find_tools StructuredTool for tool discovery. @@ -336,6 +337,7 @@ async def create_find_tools_tool( all_tools: All available tools to search through all_apps: All available app definitions app_to_tools_map: Optional mapping of app_name -> list of tools. If provided, used for filtering by app_name. + filter_criteria: Optional dictionary of metadata filters to apply to tools (e.g., {"domain": "hockey"}) Returns: StructuredTool configured for finding relevant tools @@ -359,6 +361,29 @@ async def find_tools_func(query: str, app_name: str): ) filtered_tools = [] + # Apply filter criteria if specified + if filter_criteria and filtered_tools: + criteria_filtered_tools = [] + for tool in filtered_tools: + if hasattr(tool, 'metadata') and tool.metadata: + # Check if all filter criteria match + matches_all = all( + tool.metadata.get(key) == value + for key, value in filter_criteria.items() + ) + if matches_all: + criteria_filtered_tools.append(tool) + + if criteria_filtered_tools: + filtered_tools = criteria_filtered_tools + criteria_str = ", ".join(f"{k}='{v}'" for k, v in filter_criteria.items()) + logger.info(f"Filtered {len(filtered_tools)} tools matching criteria: {criteria_str}") + else: + criteria_str = ", ".join(f"{k}='{v}'" for k, v in filter_criteria.items()) + logger.warning( + f"No tools found matching criteria ({criteria_str}) in app '{app_name}'. Using all {len(filtered_tools)} tools from app." + ) + filtered_apps = [app for app in all_apps if hasattr(app, 'name') and app.name == app_name] if not filtered_apps: @@ -453,6 +478,7 @@ def create_cuga_lite_graph( thread_id: Optional[str] = None, callbacks: Optional[List[BaseCallbackHandler]] = None, special_instructions: Optional[str] = None, + filter_criteria: Optional[Dict[str, Any]] = None, ) -> StateGraph: """ Create a unified CugaLite subgraph combining CodeAct and CugaAgent functionality. @@ -487,6 +513,7 @@ def create_prepare_node( base_instructions, tools_context_dict, base_special_instructions, + base_filter_criteria, ): """Factory to create prepare node with closure over tool provider and config.""" @@ -605,7 +632,7 @@ async def prepare_tools_and_apps( tools_for_prompt = tools_for_execution if enable_find_tools: find_tool = await create_find_tools_tool( - all_tools=tools_for_execution, all_apps=apps_for_prompt, app_to_tools_map=app_to_tools_map + all_tools=tools_for_execution, all_apps=apps_for_prompt, app_to_tools_map=app_to_tools_map, filter_criteria=base_filter_criteria ) tools_for_prompt = [find_tool] # Add find_tools to tools context for sandbox execution @@ -1106,7 +1133,7 @@ async def sandbox(state: CugaLiteState, config: Optional[RunnableConfig] = None) # Create node instances using factories prepare_node = create_prepare_node( - tool_provider, prompt_template, instructions, tools_context, special_instructions + tool_provider, prompt_template, instructions, tools_context, special_instructions, filter_criteria ) call_model_node = create_call_model_node(model, callbacks) sandbox_node = create_sandbox_node(tools_context, thread_id, apps_list) diff --git a/src/cuga/sdk.py b/src/cuga/sdk.py index f9ea080e..0c1f8015 100644 --- a/src/cuga/sdk.py +++ b/src/cuga/sdk.py @@ -1205,6 +1205,7 @@ def __init__( auto_load_policies: Optional[bool] = None, reset_policy_storage: bool = False, filesystem_sync: Optional[bool] = None, + filter_criteria: Optional[Dict[str, Any]] = None, ): """ Initialize the CUGA Agent. @@ -1220,6 +1221,7 @@ def __init__( auto_load_policies: If True, automatically loads policies from cuga_folder reset_policy_storage: If True, clears all existing policies from storage on init filesystem_sync: If True, saves policies to .cuga when added/updated (default: True) + filter_criteria: Optional dictionary of metadata criteria to filter tools (e.g., {"domain": "sports"}) Example with tool approval policy: ```python @@ -1252,6 +1254,7 @@ def __init__( self._compiled_graph = None self._policy_system = policy_system self._special_instructions = special_instructions + self._filter_criteria = filter_criteria # Use settings defaults if not provided self.cuga_folder = cuga_folder if cuga_folder is not None else settings.policy.cuga_folder @@ -1357,6 +1360,7 @@ def _create_hitl_wrapper_graph(self, thread_id: Optional[str] = None): thread_id=thread_id, callbacks=self._callbacks, special_instructions=self._special_instructions, + filter_criteria=self._filter_criteria, ) # Compile subgraph without checkpointer so it streams internal updates compiled_subgraph = cuga_lite_subgraph.compile()