|
| 1 | +""" |
| 2 | +Semantic Search Demo - Local BM25 vs Semantic Search |
| 3 | +
|
| 4 | +Demonstrates how semantic search understands natural language intent |
| 5 | +while local keyword search fails on synonyms and colloquial queries. |
| 6 | +
|
| 7 | +Run with local Lambda: |
| 8 | + cd ai-generation/apps/action_search && make run-local |
| 9 | + uv run python examples/demo_semantic_search.py --local |
| 10 | +
|
| 11 | +Run with production API: |
| 12 | + STACKONE_API_KEY=xxx uv run python examples/demo_semantic_search.py |
| 13 | +""" |
| 14 | + |
| 15 | +from __future__ import annotations |
| 16 | + |
| 17 | +import argparse |
| 18 | +import os |
| 19 | +import time |
| 20 | +from dataclasses import dataclass |
| 21 | +from typing import Any |
| 22 | + |
| 23 | +import httpx |
| 24 | + |
| 25 | +from stackone_ai.semantic_search import ( |
| 26 | + SemanticSearchClient, |
| 27 | + SemanticSearchResponse, |
| 28 | + SemanticSearchResult, |
| 29 | +) |
| 30 | +from stackone_ai.utility_tools import ToolIndex |
| 31 | + |
# Invocation endpoint of the local Lambda runtime emulator started by
# `make run-local`; the path follows the AWS Lambda Invoke API format.
DEFAULT_LAMBDA_URL = "http://localhost:4513/2015-03-31/functions/function/invocations"

# Demo queries - the strongest "wow" moments from benchmark results.
# Each entry pairs a colloquial query with a one-line "why" explaining the
# linguistic gap (synonym, intent, abbreviation) that defeats keyword search.
DEMO_QUERIES: list[dict[str, str]] = [
    {
        "query": "fire someone",
        "why": "Synonym: 'fire' = terminate employment",
    },
    {
        "query": "ping the team",
        "why": "Intent: 'ping' = send a message",
    },
    {
        "query": "file a new bug",
        "why": "Intent: 'file a bug' = create issue (not file operations)",
    },
    {
        "query": "check my to-do list",
        "why": "Concept: 'to-do list' = list tasks",
    },
    {
        "query": "show me everyone in the company",
        "why": "Synonym: 'everyone in company' = list employees",
    },
    {
        "query": "turn down a job seeker",
        "why": "Synonym: 'turn down' = reject application",
    },
    {
        "query": "approve PTO",
        "why": "Abbreviation: 'PTO' = paid time off request",
    },
    {
        "query": "grab that spreadsheet",
        "why": "Colloquial: 'grab' = download file",
    },
]
| 70 | + |
| 71 | + |
@dataclass
class LightweightTool:
    """Minimal tool record for BM25 indexing.

    Carries only the fields the demo's keyword index consumes — presumably
    the ones ToolIndex reads; note the ``type: ignore[arg-type]`` at the
    ToolIndex call site in run_demo, so confirm against ToolIndex's API.
    """

    # Fully-qualified action name, e.g. "bamboohr_1.0.0_bamboohr_list_employees_global".
    name: str
    # Human-readable action description (used as searchable text).
    description: str
| 78 | + |
| 79 | + |
class LocalLambdaClient:
    """Client for the local action_search Lambda.

    Speaks the Lambda Invoke HTTP protocol: every call POSTs a
    ``{"type": ..., "payload": ...}`` envelope to a single URL.
    """

    # Broad seed queries used to approximate a full catalog dump: the Lambda
    # only exposes search, so we union wide result pages across generic terms.
    _CATALOG_SEED_QUERIES = (
        "employee",
        "candidate",
        "contact",
        "task",
        "message",
        "file",
        "event",
        "deal",
    )

    def __init__(self, url: str = DEFAULT_LAMBDA_URL) -> None:
        self.url = url

    def search(
        self,
        query: str,
        connector: str | None = None,
        top_k: int = 5,
    ) -> SemanticSearchResponse:
        """Run one semantic search against the local Lambda.

        Args:
            query: Natural-language search query.
            connector: Optional connector key to restrict results to.
            top_k: Maximum number of results to request.

        Returns:
            A SemanticSearchResponse; missing fields in the Lambda reply
            default to empty strings / 0.0 so partial payloads don't crash.

        Raises:
            httpx.HTTPStatusError: On a non-2xx Lambda response.
        """
        payload: dict[str, Any] = {
            "type": "search",
            "payload": {"query": query, "top_k": top_k},
        }
        if connector:
            payload["payload"]["connector"] = connector

        resp = httpx.post(self.url, json=payload, timeout=30.0)
        resp.raise_for_status()
        data = resp.json()

        results = [
            SemanticSearchResult(
                action_name=r.get("action_name", ""),
                connector_key=r.get("connector_key", ""),
                similarity_score=r.get("similarity_score", 0.0),
                label=r.get("label", ""),
                description=r.get("description", ""),
            )
            for r in data.get("results", [])
        ]
        return SemanticSearchResponse(
            results=results,
            total_count=data.get("total_count", len(results)),
            query=data.get("query", query),
        )

    def fetch_actions(self) -> list[LightweightTool]:
        """Fetch broad action catalog for BM25 index.

        Issues one wide (top_k=500) search per seed query and de-duplicates
        by action name. Each request is best-effort: any failure skips that
        seed query so the demo degrades to a smaller index instead of dying.
        """
        seen: dict[str, LightweightTool] = {}
        for q in self._CATALOG_SEED_QUERIES:
            try:
                resp = httpx.post(
                    self.url,
                    json={"type": "search", "payload": {"query": q, "top_k": 500}},
                    timeout=30.0,
                )
                # Fix: previously error responses (e.g. 500) were fed straight
                # to .json(); reject them so they are skipped like other failures.
                resp.raise_for_status()
                for r in resp.json().get("results", []):
                    name = r.get("action_name", "")
                    if name and name not in seen:
                        seen[name] = LightweightTool(name=name, description=r.get("description", ""))
            except Exception:
                continue  # deliberate best-effort: skip this seed query
        return list(seen.values())
| 136 | + |
| 137 | + |
def shorten_name(name: str) -> str:
    """Compact a fully-qualified action name for display.

    Example:
        bamboohr_1.0.0_bamboohr_list_employees_global -> "bamboohr: list_employees"

    Names without a recognizable version segment are returned unchanged.
    """
    segments = name.split("_")

    # Locate the version segment (e.g. "1.0.0"): contains a dot and a digit.
    version_at = next(
        (
            idx
            for idx, seg in enumerate(segments)
            if "." in seg and any(ch.isdigit() for ch in seg)
        ),
        None,
    )
    if version_at is None:
        return name

    connector = segments[0]
    tail = segments[version_at + 1 :]

    # Drop a repeated connector prefix (compare ignoring hyphens and case).
    if tail and tail[0].lower().replace("-", "") == connector.lower().replace("-", ""):
        tail = tail[1:]
    # Drop a trailing "global" scope marker.
    if tail and tail[-1] == "global":
        tail = tail[:-1]

    return f"{connector}: {'_'.join(tail)}"
| 165 | + |
| 166 | + |
def print_header(text: str) -> None:
    """Print *text* framed above and below by 70-char '=' rules."""
    rule = "=" * 70
    print(f"\n{rule}\n {text}\n{rule}")
| 171 | + |
| 172 | + |
def print_section(text: str) -> None:
    """Print a lightweight '--- text ---' section divider."""
    print("", f"--- {text} ---", "", sep="\n")
| 175 | + |
| 176 | + |
def run_demo(use_local: bool, lambda_url: str, api_key: str | None) -> None:
    """Drive the interactive demo: setup, per-query BM25-vs-semantic
    comparison, benchmark summary, and API usage examples.

    Args:
        use_local: When True, talk to a locally running Lambda; otherwise
            use the production SemanticSearchClient.
        lambda_url: URL of the local Lambda invocation endpoint (used for the
            BM25 catalog in both modes).
        api_key: StackOne API key; required only when use_local is False.
    """
    # Step 1: Setup — pick the semantic-search callable for the chosen mode.
    if use_local:
        client = LocalLambdaClient(url=lambda_url)
        semantic_search = client.search
    else:
        if not api_key:
            print("Error: STACKONE_API_KEY required for production mode")
            print("Use --local flag for local Lambda mode")
            # NOTE(review): 'exit' is the site-module builtin; sys.exit is the
            # conventional choice in scripts — confirm and consider switching.
            exit(1)
        # NOTE(review): assumes SemanticSearchClient.search accepts the same
        # (query=..., top_k=...) keywords as LocalLambdaClient.search — verify.
        sem_client = SemanticSearchClient(api_key=api_key)
        semantic_search = sem_client.search
        client = None

    print_header("SEMANTIC SEARCH DEMO")
    # NOTE(review): the action count below is hard-coded; the real indexed
    # count is printed dynamically after the catalog load.
    print("\n Comparing Local BM25+TF-IDF vs Semantic Search")
    print(" across 5,144 actions from 200+ connectors\n")

    # Step 2: Build local BM25 index from a best-effort catalog dump.
    print(" Loading action catalog for local BM25 index...")
    if use_local:
        tools = client.fetch_actions()
    else:
        # For production mode, use semantic search to build catalog
        # (still requires the local Lambda to be running for this step).
        local_client = LocalLambdaClient(url=lambda_url)
        tools = local_client.fetch_actions()

    local_index = ToolIndex(tools)  # type: ignore[arg-type]
    print(f" Indexed {len(tools)} actions\n")

    input(" Press Enter to start the demo...\n")

    # Step 3: Side-by-side comparison, one interactive screen per query.
    print_header("SIDE-BY-SIDE COMPARISON")

    # NOTE(review): these counters are never incremented or printed — either
    # wire them up or remove them.
    local_hits = 0
    semantic_hits = 0

    for i, demo in enumerate(DEMO_QUERIES, 1):
        query = demo["query"]
        why = demo["why"]

        print(f"\n [{i}/{len(DEMO_QUERIES)}] Query: \"{query}\"")
        print(f" Why interesting: {why}")
        print()

        # Local search (timed in milliseconds).
        start = time.perf_counter()
        local_results = local_index.search(query, limit=3)
        local_ms = (time.perf_counter() - start) * 1000
        local_names = [shorten_name(r.name) for r in local_results]

        # Semantic search (timed in milliseconds; includes network latency).
        start = time.perf_counter()
        sem_response = semantic_search(query=query, top_k=3)
        sem_ms = (time.perf_counter() - start) * 1000
        sem_names = [shorten_name(r.action_name) for r in sem_response.results]
        sem_scores = [f"{r.similarity_score:.2f}" for r in sem_response.results]

        # Display: two fixed-width columns (w chars each) separated by '|'.
        w = 38
        print(f" {'Local BM25 (keyword)':<{w}} | {'Semantic Search (AI)':<{w}}")
        print(f" {f'{local_ms:.1f}ms':<{w}} | {f'{sem_ms:.1f}ms':<{w}}")
        print(f" {'-' * w} | {'-' * w}")
        for j in range(min(3, max(len(local_names), len(sem_names)))):
            # Pad the shorter result list with blanks so rows stay aligned.
            l_name = local_names[j] if j < len(local_names) else ""
            s_name = sem_names[j] if j < len(sem_names) else ""
            s_score = sem_scores[j] if j < len(sem_scores) else ""
            l_display = f" {l_name[:w]:<{w}}"
            # Reserve 8 chars on the right for the " (0.00)" score suffix.
            s_display = f" {s_name[:w - 8]:<{w - 8}} ({s_score})" if s_name else ""
            print(f"{l_display} |{s_display}")

        input("\n Press Enter for next query...")

    # Step 4: Summary — static numbers from a prior offline benchmark run.
    print_header("BENCHMARK RESULTS (94 evaluation tasks)")

    print("""
    Method                     Hit@5      MRR       Avg Latency
    ----------------------------------------------------------
    Local BM25+TF-IDF          66.0%      0.538     1.2ms
    Semantic Search            76.6%      0.634     279.6ms
    ----------------------------------------------------------
    Improvement                +10.6%     +0.096
    """)

    # Step 5: Code examples — printed verbatim for developers to copy.
    print_header("DEVELOPER API")

    print("""
    # 1. Direct semantic search
    from stackone_ai import StackOneToolSet

    toolset = StackOneToolSet(api_key="xxx")
    tools = toolset.search_tools("fire someone", top_k=5)
    # Returns: terminate_employee, offboard_employee, ...


    # 2. Semantic search with connector filter
    tools = toolset.search_tools(
        "send a message",
        connector="slack",
        top_k=3,
    )
    # Returns: slack_send_message, slack_create_conversation, ...


    # 3. MCP utility tool (for AI agents)
    tools = toolset.fetch_tools()
    utility = tools.utility_tools(use_semantic_search=True)
    # AI agent gets: tool_search (semantic-powered) + tool_execute


    # 4. Inspect results before fetching
    results = toolset.search_action_names("onboard new hire")
    for r in results:
        print(f"{r.action_name}: {r.similarity_score:.2f}")
    """)

    print_header("END OF DEMO")
| 297 | + |
| 298 | + |
def main() -> None:
    """CLI entry point: parse command-line flags and launch the demo."""
    parser = argparse.ArgumentParser(description="Semantic Search Demo")
    parser.add_argument("--local", action="store_true", help="Use local Lambda")
    parser.add_argument("--lambda-url", default=DEFAULT_LAMBDA_URL, help="Lambda URL")
    ns = parser.parse_args()

    run_demo(
        use_local=ns.local,
        lambda_url=ns.lambda_url,
        api_key=os.environ.get("STACKONE_API_KEY"),
    )


if __name__ == "__main__":
    main()
0 commit comments