1111from collections .abc import Coroutine
1212from dataclasses import dataclass
1313from importlib import metadata
14- from typing import Any , TypeVar
14+ from typing import Any , Literal , TypeVar
1515
1616from stackone_ai .models import (
1717 ExecuteConfig ,
2929
3030logger = logging .getLogger ("stackone.tools" )
3131
32+ SearchMode = Literal ["auto" , "semantic" , "local" ]
33+
3234try :
3335 _SDK_VERSION = metadata .version ("stackone-ai" )
3436except metadata .PackageNotFoundError : # pragma: no cover - best-effort fallback when running from source
@@ -246,8 +248,9 @@ class SearchTool:
246248 tools = search_tool("manage employee records") # returns Tools
247249 """
248250
249- def __init__ (self , toolset : StackOneToolSet ) -> None :
251+ def __init__ (self , toolset : StackOneToolSet , search : SearchMode = "auto" ) -> None :
250252 self ._toolset = toolset
253+ self ._search = search
251254
252255 def __call__ (
253256 self ,
@@ -257,6 +260,7 @@ def __call__(
257260 top_k : int | None = None ,
258261 min_similarity : float | None = None ,
259262 account_ids : list [str ] | None = None ,
263+ search : SearchMode | None = None ,
260264 ) -> Tools :
261265 """Search for tools using natural language.
262266
@@ -266,6 +270,7 @@ def __call__(
266270 top_k: Maximum number of tools to return
267271 min_similarity: Minimum similarity score threshold 0-1
268272 account_ids: Optional account IDs (uses set_accounts() if not provided)
273+ search: Override the default search mode for this call
269274
270275 Returns:
271276 Tools collection with matched tools
@@ -276,6 +281,7 @@ def __call__(
276281 top_k = top_k ,
277282 min_similarity = min_similarity ,
278283 account_ids = account_ids ,
284+ search = search if search is not None else self ._search ,
279285 )
280286
281287
@@ -325,13 +331,17 @@ def set_accounts(self, account_ids: list[str]) -> StackOneToolSet:
325331 self ._account_ids = account_ids
326332 return self
327333
328- def get_search_tool (self ) -> SearchTool :
334+ def get_search_tool (self , * , search : SearchMode = "auto" ) -> SearchTool :
329335 """Get a callable search tool that returns Tools collections.
330336
331337 Returns a callable that wraps :meth:`search_tools` for use in agent loops.
332338 The returned tool is directly callable: ``search_tool("query")`` returns
333339 :class:`Tools`.
334340
341+ Args:
342+ search: Default search mode for the returned tool. Can be overridden
343+ per-call. See :meth:`search_tools` for details.
344+
335345 Returns:
336346 SearchTool instance
337347
@@ -342,7 +352,7 @@ def get_search_tool(self) -> SearchTool:
342352 search_tool = toolset.get_search_tool()
343353 tools = search_tool("manage employee records")
344354 """
345- return SearchTool (self )
355+ return SearchTool (self , search = search )
346356
347357 @property
348358 def semantic_client (self ) -> SemanticSearchClient :
@@ -358,6 +368,38 @@ def semantic_client(self) -> SemanticSearchClient:
358368 )
359369 return self ._semantic_client
360370
371+ def _local_search (
372+ self ,
373+ query : str ,
374+ all_tools : Tools ,
375+ * ,
376+ connector : str | None = None ,
377+ top_k : int | None = None ,
378+ min_similarity : float | None = None ,
379+ ) -> Tools :
380+ """Run local BM25+TF-IDF search over already-fetched tools."""
381+ from stackone_ai .local_search import ToolIndex
382+
383+ available_connectors = all_tools .get_connectors ()
384+ if not available_connectors :
385+ return Tools ([])
386+
387+ index = ToolIndex (list (all_tools ))
388+ results = index .search (
389+ query ,
390+ limit = top_k if top_k is not None else 5 ,
391+ min_score = min_similarity if min_similarity is not None else 0.0 ,
392+ )
393+ matched_names = [r .name for r in results ]
394+ tool_map = {t .name : t for t in all_tools }
395+ filter_connectors = {connector .lower ()} if connector else available_connectors
396+ matched_tools = [
397+ tool_map [name ]
398+ for name in matched_names
399+ if name in tool_map and name .split ("_" )[0 ].lower () in filter_connectors
400+ ]
401+ return Tools (matched_tools [:top_k ] if top_k is not None else matched_tools )
402+
361403 def search_tools (
362404 self ,
363405 query : str ,
@@ -366,13 +408,11 @@ def search_tools(
366408 top_k : int | None = None ,
367409 min_similarity : float | None = None ,
368410 account_ids : list [str ] | None = None ,
369- fallback_to_local : bool = True ,
411+ search : SearchMode = "auto" ,
370412 ) -> Tools :
371- """Search for and fetch tools using semantic search.
413+ """Search for and fetch tools using semantic or local search.
372414
373- This method uses the StackOne semantic search API to find relevant tools
374- based on natural language queries. It optimizes results by filtering to
375- only connectors available in linked accounts.
415+ This method discovers relevant tools based on natural language queries.
376416
377417 Args:
378418 query: Natural language description of needed functionality
@@ -382,30 +422,35 @@ def search_tools(
382422 min_similarity: Minimum similarity score threshold 0-1. If not provided,
383423 the server uses its default.
384424 account_ids: Optional account IDs (uses set_accounts() if not provided)
385- fallback_to_local: If True, fall back to local BM25+TF-IDF search on API failure
425+ search: Search backend to use:
426+ - ``"auto"`` (default): try semantic search first, fall back to local
427+ BM25+TF-IDF if the API is unavailable.
428+ - ``"semantic"``: use only the semantic search API; raises
429+ ``SemanticSearchError`` on failure.
430+ - ``"local"``: use only local BM25+TF-IDF search (no API call to the
431+ semantic search endpoint).
386432
387433 Returns:
388- Tools collection with semantically matched tools from linked accounts
434+ Tools collection with matched tools from linked accounts
389435
390436 Raises:
391- SemanticSearchError: If the API call fails and fallback_to_local is False
437+ SemanticSearchError: If the API call fails and search is ``"semantic"``
392438
393439 Examples:
394- # Basic semantic search
440+ # Semantic search (default with local fallback)
395441 tools = toolset.search_tools("manage employee records", top_k=5)
396442
397- # Filter by connector with minimum similarity
443+ # Explicit semantic search
444+ tools = toolset.search_tools("manage employees", search="semantic")
445+
446+ # Local BM25+TF-IDF search
447+ tools = toolset.search_tools("manage employees", search="local")
448+
449+ # Filter by connector
398450 tools = toolset.search_tools(
399451 "create time off request",
400452 connector="bamboohr",
401- min_similarity=0.5
402- )
403-
404- # With account filtering
405- tools = toolset.search_tools(
406- "send message",
407- account_ids=["acc-123"],
408- top_k=3
453+ search="semantic",
409454 )
410455 """
411456 all_tools = self .fetch_tools (account_ids = account_ids )
@@ -414,16 +459,22 @@ def search_tools(
414459 if not available_connectors :
415460 return Tools ([])
416461
462+ # Local-only search — skip semantic API entirely
463+ if search == "local" :
464+ return self ._local_search (
465+ query , all_tools , connector = connector , top_k = top_k , min_similarity = min_similarity
466+ )
467+
417468 try :
418- # Step 2: Determine which connectors to search
469+ # Determine which connectors to search
419470 if connector :
420471 connectors_to_search = {connector .lower ()} & available_connectors
421472 if not connectors_to_search :
422473 return Tools ([])
423474 else :
424475 connectors_to_search = available_connectors
425476
426- # Step 3: Search each connector in parallel
477+ # Search each connector in parallel
427478 def _search_one (c : str ) -> list [SemanticSearchResult ]:
428479 resp = self .semantic_client .search (
429480 query = query , connector = c , top_k = top_k , min_similarity = min_similarity
@@ -445,15 +496,15 @@ def _search_one(c: str) -> list[SemanticSearchResult]:
445496 if not all_results and last_error is not None :
446497 raise last_error
447498
448- # Step 4: Sort by score, apply top_k
499+ # Sort by score, apply top_k
449500 all_results .sort (key = lambda r : r .similarity_score , reverse = True )
450501 if top_k is not None :
451502 all_results = all_results [:top_k ]
452503
453504 if not all_results :
454505 return Tools ([])
455506
456- # Step 5: Match back to fetched tool definitions
507+ # Match back to fetched tool definitions
457508 action_names = {_normalize_action_name (r .action_name ) for r in all_results }
458509 matched_tools = [t for t in all_tools if t .name in action_names ]
459510
@@ -464,28 +515,13 @@ def _search_one(c: str) -> list[SemanticSearchResult]:
464515 return Tools (matched_tools )
465516
466517 except SemanticSearchError as e :
467- if not fallback_to_local :
518+ if search == "semantic" :
468519 raise
469520
470521 logger .warning ("Semantic search failed (%s), falling back to local BM25+TF-IDF search" , e )
471-
472- from stackone_ai .local_search import ToolIndex
473-
474- index = ToolIndex (list (all_tools ))
475- results = index .search (
476- query ,
477- limit = top_k if top_k is not None else 5 ,
478- min_score = min_similarity if min_similarity is not None else 0.0 ,
522+ return self ._local_search (
523+ query , all_tools , connector = connector , top_k = top_k , min_similarity = min_similarity
479524 )
480- matched_names = [r .name for r in results ]
481- tool_map = {t .name : t for t in all_tools }
482- filter_connectors = {connector .lower ()} if connector else available_connectors
483- matched_tools = [
484- tool_map [name ]
485- for name in matched_names
486- if name in tool_map and name .split ("_" )[0 ].lower () in filter_connectors
487- ]
488- return Tools (matched_tools [:top_k ] if top_k is not None else matched_tools )
489525
490526 def search_action_names (
491527 self ,
0 commit comments