Skip to content

Commit 4785d87

Browse files
Add the search modes for local, semantic and auto with example
1 parent f6920c8 commit 4785d87

File tree

5 files changed

+237
-49
lines changed

5 files changed

+237
-49
lines changed

examples/search_tool_example.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,47 @@ def example_search_tool_basic():
5959
print()
6060

6161

62+
def example_search_modes():
63+
"""Comparing semantic vs local search modes.
64+
65+
The search parameter controls which backend search_tools() uses:
66+
- "semantic": cloud-based semantic vector search (higher accuracy for natural language)
67+
- "local": local BM25+TF-IDF hybrid search (no network call to semantic API)
68+
- "auto" (default): tries semantic first, falls back to local on failure
69+
"""
70+
print("Example 2: Semantic vs local search modes\n")
71+
72+
toolset = StackOneToolSet()
73+
query = "manage employee time off"
74+
75+
# Semantic search — uses StackOne's semantic search API
76+
print('search="semantic": cloud-based semantic vector search')
77+
try:
78+
tools_semantic = toolset.search_tools(query, account_ids=_account_ids, top_k=5, search="semantic")
79+
print(f" Found {len(tools_semantic)} tools:")
80+
for tool in tools_semantic:
81+
print(f" - {tool.name}")
82+
except Exception as e:
83+
print(f" Semantic search unavailable: {e}")
84+
print()
85+
86+
# Local search — BM25+TF-IDF, no semantic API call
87+
print('search="local": local BM25+TF-IDF hybrid search')
88+
tools_local = toolset.search_tools(query, account_ids=_account_ids, top_k=5, search="local")
89+
print(f" Found {len(tools_local)} tools:")
90+
for tool in tools_local:
91+
print(f" - {tool.name}")
92+
print()
93+
94+
# Auto (default) — tries semantic, falls back to local
95+
print('search="auto" (default): semantic with local fallback')
96+
tools_auto = toolset.search_tools(query, account_ids=_account_ids, top_k=5, search="auto")
97+
print(f" Found {len(tools_auto)} tools:")
98+
for tool in tools_auto:
99+
print(f" - {tool.name}")
100+
print()
101+
102+
62103
def example_search_tool_with_execution():
63104
"""Example of discovering and executing tools dynamically"""
64105
print("Example 2: Dynamic tool execution\n")
@@ -211,6 +252,7 @@ def main():
211252

212253
# Basic examples that work without external APIs
213254
example_search_tool_basic()
255+
example_search_modes()
214256
example_search_tool_with_execution()
215257

216258
# Examples that require OpenAI API

stackone_ai/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
SemanticSearchResponse,
88
SemanticSearchResult,
99
)
10-
from stackone_ai.toolset import SearchTool, StackOneToolSet
10+
from stackone_ai.toolset import SearchMode, SearchTool, StackOneToolSet
1111

1212
__all__ = [
1313
"StackOneToolSet",
1414
"StackOneTool",
1515
"Tools",
16+
"SearchMode",
1617
"SearchTool",
1718
# Semantic search
1819
"SemanticSearchClient",

stackone_ai/semantic_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
2828
If the semantic API is unavailable, the SDK falls back to a local
2929
BM25 + TF-IDF hybrid search over the fetched tools (unless
30-
``fallback_to_local=False``).
30+
``search="semantic"`` is specified).
3131
3232
3333
2. ``search_action_names(query)`` — Lightweight discovery

stackone_ai/toolset.py

Lines changed: 80 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from collections.abc import Coroutine
1212
from dataclasses import dataclass
1313
from importlib import metadata
14-
from typing import Any, TypeVar
14+
from typing import Any, Literal, TypeVar
1515

1616
from stackone_ai.models import (
1717
ExecuteConfig,
@@ -29,6 +29,8 @@
2929

3030
logger = logging.getLogger("stackone.tools")
3131

32+
SearchMode = Literal["auto", "semantic", "local"]
33+
3234
try:
3335
_SDK_VERSION = metadata.version("stackone-ai")
3436
except metadata.PackageNotFoundError: # pragma: no cover - best-effort fallback when running from source
@@ -246,8 +248,9 @@ class SearchTool:
246248
tools = search_tool("manage employee records") # returns Tools
247249
"""
248250

249-
def __init__(self, toolset: StackOneToolSet) -> None:
251+
def __init__(self, toolset: StackOneToolSet, search: SearchMode = "auto") -> None:
250252
self._toolset = toolset
253+
self._search = search
251254

252255
def __call__(
253256
self,
@@ -257,6 +260,7 @@ def __call__(
257260
top_k: int | None = None,
258261
min_similarity: float | None = None,
259262
account_ids: list[str] | None = None,
263+
search: SearchMode | None = None,
260264
) -> Tools:
261265
"""Search for tools using natural language.
262266
@@ -266,6 +270,7 @@ def __call__(
266270
top_k: Maximum number of tools to return
267271
min_similarity: Minimum similarity score threshold 0-1
268272
account_ids: Optional account IDs (uses set_accounts() if not provided)
273+
search: Override the default search mode for this call
269274
270275
Returns:
271276
Tools collection with matched tools
@@ -276,6 +281,7 @@ def __call__(
276281
top_k=top_k,
277282
min_similarity=min_similarity,
278283
account_ids=account_ids,
284+
search=search if search is not None else self._search,
279285
)
280286

281287

@@ -325,13 +331,17 @@ def set_accounts(self, account_ids: list[str]) -> StackOneToolSet:
325331
self._account_ids = account_ids
326332
return self
327333

328-
def get_search_tool(self) -> SearchTool:
334+
def get_search_tool(self, *, search: SearchMode = "auto") -> SearchTool:
329335
"""Get a callable search tool that returns Tools collections.
330336
331337
Returns a callable that wraps :meth:`search_tools` for use in agent loops.
332338
The returned tool is directly callable: ``search_tool("query")`` returns
333339
:class:`Tools`.
334340
341+
Args:
342+
search: Default search mode for the returned tool. Can be overridden
343+
per-call. See :meth:`search_tools` for details.
344+
335345
Returns:
336346
SearchTool instance
337347
@@ -342,7 +352,7 @@ def get_search_tool(self) -> SearchTool:
342352
search_tool = toolset.get_search_tool()
343353
tools = search_tool("manage employee records")
344354
"""
345-
return SearchTool(self)
355+
return SearchTool(self, search=search)
346356

347357
@property
348358
def semantic_client(self) -> SemanticSearchClient:
@@ -358,6 +368,38 @@ def semantic_client(self) -> SemanticSearchClient:
358368
)
359369
return self._semantic_client
360370

371+
def _local_search(
372+
self,
373+
query: str,
374+
all_tools: Tools,
375+
*,
376+
connector: str | None = None,
377+
top_k: int | None = None,
378+
min_similarity: float | None = None,
379+
) -> Tools:
380+
"""Run local BM25+TF-IDF search over already-fetched tools."""
381+
from stackone_ai.local_search import ToolIndex
382+
383+
available_connectors = all_tools.get_connectors()
384+
if not available_connectors:
385+
return Tools([])
386+
387+
index = ToolIndex(list(all_tools))
388+
results = index.search(
389+
query,
390+
limit=top_k if top_k is not None else 5,
391+
min_score=min_similarity if min_similarity is not None else 0.0,
392+
)
393+
matched_names = [r.name for r in results]
394+
tool_map = {t.name: t for t in all_tools}
395+
filter_connectors = {connector.lower()} if connector else available_connectors
396+
matched_tools = [
397+
tool_map[name]
398+
for name in matched_names
399+
if name in tool_map and name.split("_")[0].lower() in filter_connectors
400+
]
401+
return Tools(matched_tools[:top_k] if top_k is not None else matched_tools)
402+
361403
def search_tools(
362404
self,
363405
query: str,
@@ -366,13 +408,11 @@ def search_tools(
366408
top_k: int | None = None,
367409
min_similarity: float | None = None,
368410
account_ids: list[str] | None = None,
369-
fallback_to_local: bool = True,
411+
search: SearchMode = "auto",
370412
) -> Tools:
371-
"""Search for and fetch tools using semantic search.
413+
"""Search for and fetch tools using semantic or local search.
372414
373-
This method uses the StackOne semantic search API to find relevant tools
374-
based on natural language queries. It optimizes results by filtering to
375-
only connectors available in linked accounts.
415+
This method discovers relevant tools based on natural language queries.
376416
377417
Args:
378418
query: Natural language description of needed functionality
@@ -382,30 +422,35 @@ def search_tools(
382422
min_similarity: Minimum similarity score threshold 0-1. If not provided,
383423
the server uses its default.
384424
account_ids: Optional account IDs (uses set_accounts() if not provided)
385-
fallback_to_local: If True, fall back to local BM25+TF-IDF search on API failure
425+
search: Search backend to use:
426+
- ``"auto"`` (default): try semantic search first, fall back to local
427+
BM25+TF-IDF if the API is unavailable.
428+
- ``"semantic"``: use only the semantic search API; raises
429+
``SemanticSearchError`` on failure.
430+
- ``"local"``: use only local BM25+TF-IDF search (no API call to the
431+
semantic search endpoint).
386432
387433
Returns:
388-
Tools collection with semantically matched tools from linked accounts
434+
Tools collection with matched tools from linked accounts
389435
390436
Raises:
391-
SemanticSearchError: If the API call fails and fallback_to_local is False
437+
SemanticSearchError: If the API call fails and search is ``"semantic"``
392438
393439
Examples:
394-
# Basic semantic search
440+
# Semantic search (default with local fallback)
395441
tools = toolset.search_tools("manage employee records", top_k=5)
396442
397-
# Filter by connector with minimum similarity
443+
# Explicit semantic search
444+
tools = toolset.search_tools("manage employees", search="semantic")
445+
446+
# Local BM25+TF-IDF search
447+
tools = toolset.search_tools("manage employees", search="local")
448+
449+
# Filter by connector
398450
tools = toolset.search_tools(
399451
"create time off request",
400452
connector="bamboohr",
401-
min_similarity=0.5
402-
)
403-
404-
# With account filtering
405-
tools = toolset.search_tools(
406-
"send message",
407-
account_ids=["acc-123"],
408-
top_k=3
453+
search="semantic",
409454
)
410455
"""
411456
all_tools = self.fetch_tools(account_ids=account_ids)
@@ -414,16 +459,22 @@ def search_tools(
414459
if not available_connectors:
415460
return Tools([])
416461

462+
# Local-only search — skip semantic API entirely
463+
if search == "local":
464+
return self._local_search(
465+
query, all_tools, connector=connector, top_k=top_k, min_similarity=min_similarity
466+
)
467+
417468
try:
418-
# Step 2: Determine which connectors to search
469+
# Determine which connectors to search
419470
if connector:
420471
connectors_to_search = {connector.lower()} & available_connectors
421472
if not connectors_to_search:
422473
return Tools([])
423474
else:
424475
connectors_to_search = available_connectors
425476

426-
# Step 3: Search each connector in parallel
477+
# Search each connector in parallel
427478
def _search_one(c: str) -> list[SemanticSearchResult]:
428479
resp = self.semantic_client.search(
429480
query=query, connector=c, top_k=top_k, min_similarity=min_similarity
@@ -445,15 +496,15 @@ def _search_one(c: str) -> list[SemanticSearchResult]:
445496
if not all_results and last_error is not None:
446497
raise last_error
447498

448-
# Step 4: Sort by score, apply top_k
499+
# Sort by score, apply top_k
449500
all_results.sort(key=lambda r: r.similarity_score, reverse=True)
450501
if top_k is not None:
451502
all_results = all_results[:top_k]
452503

453504
if not all_results:
454505
return Tools([])
455506

456-
# Step 5: Match back to fetched tool definitions
507+
# Match back to fetched tool definitions
457508
action_names = {_normalize_action_name(r.action_name) for r in all_results}
458509
matched_tools = [t for t in all_tools if t.name in action_names]
459510

@@ -464,28 +515,13 @@ def _search_one(c: str) -> list[SemanticSearchResult]:
464515
return Tools(matched_tools)
465516

466517
except SemanticSearchError as e:
467-
if not fallback_to_local:
518+
if search == "semantic":
468519
raise
469520

470521
logger.warning("Semantic search failed (%s), falling back to local BM25+TF-IDF search", e)
471-
472-
from stackone_ai.local_search import ToolIndex
473-
474-
index = ToolIndex(list(all_tools))
475-
results = index.search(
476-
query,
477-
limit=top_k if top_k is not None else 5,
478-
min_score=min_similarity if min_similarity is not None else 0.0,
522+
return self._local_search(
523+
query, all_tools, connector=connector, top_k=top_k, min_similarity=min_similarity
479524
)
480-
matched_names = [r.name for r in results]
481-
tool_map = {t.name: t for t in all_tools}
482-
filter_connectors = {connector.lower()} if connector else available_connectors
483-
matched_tools = [
484-
tool_map[name]
485-
for name in matched_names
486-
if name in tool_map and name.split("_")[0].lower() in filter_connectors
487-
]
488-
return Tools(matched_tools[:top_k] if top_k is not None else matched_tools)
489525

490526
def search_action_names(
491527
self,

0 commit comments

Comments
 (0)