diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f43a647..849ba7b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -26,7 +26,7 @@ jobs: run: | # Skip GPU tests as GitHub Actions runners don't have CUDA # To run GPU tests locally: pytest tests/ -v -m "gpu" - pytest tests/ -v --tb=short -m "not slow and not gpu and not integration" + python -m pytest tests/ -v --tb=short -m "not slow and not gpu and not integration" lint: runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index 55ae470..e10e7b3 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,7 @@ dist/ */.DS_Store *.DS_Store +# Evaluation/Profiling ignores +*.prof +evaluation/sandbox/results/* +!evaluation/sandbox/results/.gitkeep diff --git a/contextpilot/context_index/compute_distance_cpu.py b/contextpilot/context_index/compute_distance_cpu.py index e3f10b0..621c1d8 100755 --- a/contextpilot/context_index/compute_distance_cpu.py +++ b/contextpilot/context_index/compute_distance_cpu.py @@ -272,7 +272,7 @@ def compute_distance_matrix_cpu_optimized(contexts: List[List[int]], start = time.time() chunk_ids, original_positions, lengths, offsets = prepare_contexts_for_cpu(contexts) prep_time = time.time() - start - print(f"✓ Prepared in {prep_time:.1f}s") + print(f"+ Prepared in {prep_time:.1f}s") # Generate batches of pair indices print(f"\nGenerating pair batches...") @@ -290,7 +290,7 @@ def compute_distance_matrix_cpu_optimized(contexts: List[List[int]], if current_batch: batches.append(current_batch) - print(f"✓ Generated {len(batches):,} batches") + print(f"+ Generated {len(batches):,} batches") # Prepare arguments for workers worker_args = [ @@ -306,13 +306,13 @@ def compute_distance_matrix_cpu_optimized(contexts: List[List[int]], start_time = time.time() processed = 0 - with Pool(num_workers) as pool: - for batch_results in pool.imap_unordered(compute_batch_worker, worker_args): + if num_workers == 1: + # Bypass multiprocessing Pool entirely to save initialization overhead + for args in worker_args: + batch_results = compute_batch_worker(args) for i, j, dist in batch_results: - # Convert (i, j) to condensed index condensed_idx = n * i - i * (i + 1) // 2 + j - i - 1 condensed_distances[condensed_idx] = dist - processed += 1 # Progress update @@ -326,6 +326,27 @@ def compute_distance_matrix_cpu_optimized(contexts: List[List[int]], f"Rate: {rate:,.0f} pairs/sec | " f"Elapsed: {elapsed:.1f}s | " f"ETA: {eta:.1f}s ({eta/60:.1f} min)") + else: + with Pool(num_workers) as pool: + for batch_results in pool.imap_unordered(compute_batch_worker, worker_args): + for i, j, dist in batch_results: + # Convert (i, j) to condensed index + condensed_idx = n * i - i * (i + 1) // 2 + j - i - 1 + condensed_distances[condensed_idx] = dist + + processed += 1 + + # Progress update + if processed % 100000 == 0 or processed == num_pairs: + elapsed = time.time() - start_time + rate = processed / elapsed if elapsed > 0 else 0 + eta = (num_pairs - processed) / rate if rate > 0 else 0 + progress_pct = processed / num_pairs * 100 + + print(f" {processed:,}/{num_pairs:,} ({progress_pct:.1f}%) | " + f"Rate: {rate:,.0f} pairs/sec | " + f"Elapsed: {elapsed:.1f}s | " + f"ETA: {eta:.1f}s ({eta/60:.1f} min)") compute_time = time.time() - start_time total_time = compute_time + prep_time diff --git a/contextpilot/server/live_index.py b/contextpilot/server/live_index.py index cbaeb31..ea6df7b 100644 --- a/contextpilot/server/live_index.py +++ b/contextpilot/server/live_index.py @@ -172,15 +172,15 @@ def build_and_schedule(self, contexts: List[List[int]], print("\n1. Building static index...") self.initial_result = self.fit_transform(contexts) - print(f" ✓ Built tree with {self.initial_result.stats['total_nodes']} nodes") - print(f" ✓ Leaf nodes: {self.initial_result.stats['leaf_nodes']}") + print(f" + Built tree with {self.initial_result.stats['total_nodes']} nodes") + print(f" + Leaf nodes: {self.initial_result.stats['leaf_nodes']}") # Step 2: Inter-context scheduling print("\n2. Scheduling contexts for optimal execution...") scheduled_reordered, scheduled_originals, final_mapping, groups = \ self.inter_scheduler.schedule_contexts(self.initial_result) - print(f" ✓ Created {len(groups)} execution groups") + print(f" + Created {len(groups)} execution groups") self.scheduled_result = { 'reordered_contexts': scheduled_reordered, @@ -197,8 +197,8 @@ def build_and_schedule(self, contexts: List[List[int]], num_input_contexts=len(contexts) ) - print(f" ✓ Initialized {len(self.metadata)} nodes with metadata") - print(f" ✓ Auto-assigned {len(request_id_mapping)} request IDs") + print(f" + Initialized {len(self.metadata)} nodes with metadata") + print(f" + Auto-assigned {len(request_id_mapping)} request IDs") # Add request_id mapping to result (dict and ordered list) self.scheduled_result['request_id_mapping'] = request_id_mapping @@ -208,7 +208,7 @@ def build_and_schedule(self, contexts: List[List[int]], self.is_live = True print("\n" + "=" * 80) - print("✓ INDEX IS NOW LIVE - Ready for dynamic operations") + print("+ INDEX IS NOW LIVE - Ready for dynamic operations") print("=" * 80 + "\n") return self.scheduled_result @@ -534,8 +534,8 @@ def build_incremental(self, contexts: List[List[int]], # No match - will build new index for these unmatched_contexts.append((i, context)) - print(f" ✓ Found {len(matched_contexts)} contexts with matches") - print(f" ✓ Found {len(unmatched_contexts)} contexts without matches") + print(f" + Found {len(matched_contexts)} contexts with matches") + print(f" + Found {len(unmatched_contexts)} contexts without matches") # Prepare result arrays (will fill in order) request_ids = [None] * len(contexts) @@ -585,7 +585,7 @@ def build_incremental(self, contexts: List[List[int]], ) temp_result = temp_index.fit_transform(unmatched_only) - print(f" ✓ Built temp index with {temp_result.stats['total_nodes']} nodes") + print(f" + Built temp index with {temp_result.stats['total_nodes']} nodes") # Step 4: Merge temp index into global index print("\n4. Merging temp index into global index...") @@ -606,16 +606,16 @@ def build_incremental(self, contexts: List[List[int]], context_info.append((orig_idx, merged_request_ids[i], merged_search_paths[i])) merged_count = len(unmatched_contexts) - print(f" ✓ Merged {merged_count} new subtrees under global root") + print(f" + Merged {merged_count} new subtrees under global root") # Step 5: Schedule execution order print("\n5. Scheduling execution order for cache reuse...") scheduled_order = self._schedule_incremental(context_info) groups = self._group_by_path_prefix(context_info) - print(f" ✓ Scheduled {len(scheduled_order)} contexts into {len(groups)} groups") + print(f" + Scheduled {len(scheduled_order)} contexts into {len(groups)} groups") print("\n" + "=" * 80) - print(f"✓ INCREMENTAL BUILD COMPLETE") + print(f"+ INCREMENTAL BUILD COMPLETE") print(f" Matched & inserted: {len(matched_contexts)}") print(f" Built & merged: {merged_count}") print("=" * 80 + "\n") @@ -906,15 +906,15 @@ def schedule_only(self, contexts: List[List[int]]) -> Dict: print("\n1. Building static index...") result = self.fit_transform(contexts) - print(f" ✓ Built tree with {result.stats['total_nodes']} nodes") - print(f" ✓ Leaf nodes: {result.stats['leaf_nodes']}") + print(f" + Built tree with {result.stats['total_nodes']} nodes") + print(f" + Leaf nodes: {result.stats['leaf_nodes']}") # Step 2: Inter-context scheduling print("\n2. Scheduling contexts for optimal execution...") scheduled_reordered, scheduled_originals, final_mapping, groups = \ self.inter_scheduler.schedule_contexts(result) - print(f" ✓ Created {len(groups)} execution groups") + print(f" + Created {len(groups)} execution groups") # Return results without going live (stateless) scheduled_result = { @@ -931,7 +931,7 @@ def schedule_only(self, contexts: List[List[int]]) -> Dict: } print("\n" + "=" * 80) - print("✓ BATCH SCHEDULED (Stateless - no cache tracking)") + print("+ BATCH SCHEDULED (Stateless - no cache tracking)") print("=" * 80 + "\n") return scheduled_result diff --git a/docs/guides/multi_turn.md b/docs/guides/multi_turn.md index 021582c..461d623 100644 --- a/docs/guides/multi_turn.md +++ b/docs/guides/multi_turn.md @@ -256,10 +256,10 @@ print(f"New docs: {result['new_docs']}") # [2] | Operation | `/reorder` | `/deduplicate` | |-----------|----------|----------------| -| Index build | ✓ | ✗ | -| Clustering | ✓ | ✗ | -| Search | ✓ | ✗ | -| Deduplication | ✓ | ✓ | +| Index build | + | ✗ | +| Clustering | + | ✗ | +| Search | + | ✗ | +| Deduplication | + | + | | **Latency** | ~50-200ms | ~1-5ms | For multi-turn conversations, Turn 2+ typically doesn't need index operations — just deduplication against conversation history. The `/deduplicate` endpoint is **10-100x faster**. diff --git a/evaluation/benchmarks/run_bigcodebench_elm.py b/evaluation/benchmarks/run_bigcodebench_elm.py new file mode 100644 index 0000000..92b9b7d --- /dev/null +++ b/evaluation/benchmarks/run_bigcodebench_elm.py @@ -0,0 +1,147 @@ +import asyncio +import json +import logging +import os +import re + +# pip install datasets +from datasets import load_dataset +from openai import AsyncOpenAI + +# Set PYTHONPATH in the environment before running +from refactored_plugins.skill_index import SkillAwareContextPlugin +from refactored_plugins.dedup import ContextDedupPlugin + +logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s") +logger = logging.getLogger(__name__) + +# Create a registry of 10 dummy tools to trigger the Skill plugin +DUMMY_TOOL_REGISTRY = { + f"tool_{i}": { + "type": "function", + "function": { + "name": f"tool_{i}", + "description": f"Dummy tool number {i}" + } + } + for i in range(1, 11) +} + +async def process_task(task, skill_plugin, dedup_plugin, client, semaphore, output_file, turn_1_id): + """ + Processes a single BigCodeBench task through our ContextPilot plugins and ELM API. + """ + async with semaphore: + task_id = task.get("task_id", "unknown_task") + # BigCodeBench prompts are usually in 'complete_prompt' or 'instruction' + prompt = task.get("complete_prompt", task.get("instruction", "No prompt found.")) + + # Mock heavy agent request with redundant history and bloated tools + request = { + "user_id": "evaluator_1", + "parent_id": turn_1_id, + "_required_skills": ["tool_1", "tool_3", "tool_7"], # Require only 3 tools out of 10 + "messages": [ + {"role": "system", "content": "You are a senior python developer. Always wrap your code in ```python blocks."}, + {"role": "user", "content": "Please help me write some code."}, + {"role": "assistant", "content": "Of course! I can help you with that."}, + {"role": "user", "content": prompt} + ], + "tools": list(DUMMY_TOOL_REGISTRY.values()) + } + + # Pass through ContextPilot local plugins + optimized_request = await dedup_plugin.process(request) + optimized_request = await skill_plugin.process(optimized_request) + + # Prepare ELM API request (OpenAI-compatible) + api_kwargs = { + "model": "gpt-5.5", + "messages": optimized_request.get("messages", []) + } + if "tools" in optimized_request and optimized_request["tools"]: + api_kwargs["tools"] = optimized_request["tools"] + + try: + logger.info(f"Sending optimized task {task_id} to ELM API...") + response = await client.chat.completions.create(**api_kwargs) + response_content = response.choices[0].message.content + except Exception as e: + logger.error(f"API Error for {task_id}: {str(e)}") + response_content = "" + + # Extract code block using regex + extracted_code = "" + if response_content: + match = re.search(r"```python\s*(.*?)\s*```", response_content, re.DOTALL) + if match: + extracted_code = match.group(1).strip() + else: + # Fallback if the LLM didn't use the markdown block + extracted_code = response_content.strip() + + # Append result to JSONL + with open(output_file, "a", encoding="utf-8") as f: + f.write(json.dumps({"task_id": task_id, "solution": extracted_code}) + "\n") + + logger.info(f"Finished {task_id}") + +async def main(): + api_key = os.environ.get("OPENAI_API_KEY", "dummy-elm-key") + base_url = os.environ.get("BASE_URL", "https://api.openai.com/v1") + + client = AsyncOpenAI(api_key=api_key, base_url=base_url) + + skill_plugin = SkillAwareContextPlugin(tool_registry=DUMMY_TOOL_REGISTRY) + dedup_plugin = ContextDedupPlugin() + + # Pre-warm Dedup plugin with the initial messages to simulate conversation history + turn_1 = { + "user_id": "evaluator_1", + "messages": [ + {"role": "system", "content": "You are a senior python developer. Always wrap your code in ```python blocks."}, + {"role": "user", "content": "Please help me write some code."}, + {"role": "assistant", "content": "Of course! I can help you with that."} + ] + } + turn_1_res = await dedup_plugin.process(turn_1) + turn_1_id = turn_1_res.get("current_id") + + # Load BigCodeBench dataset + logger.info("Loading BigCodeBench dataset...") + try: + dataset = load_dataset("bigcode/bigcodebench", split="train") + except Exception as e: + logger.warning(f"Failed to load split='train'. Trying standard default split. Error: {e}") + # Fallback to the common default split format if 'train' split does not exist + try: + dataset = load_dataset("bigcode/bigcodebench", split="v0.1.2") + except Exception: + dataset = load_dataset("bigcode/bigcodebench", split="v0.1.0_240822") + + # Select first 5 tasks for a smoke test + tasks = list(dataset)[:5] + logger.info(f"Loaded {len(tasks)} tasks for smoke test.") + + output_file = os.path.join(os.path.dirname(__file__), "elm_samples.jsonl") + if os.path.exists(output_file): + os.remove(output_file) + + # Use a Semaphore with 1 to process sequentially and avoid early rate limits + semaphore = asyncio.Semaphore(1) + + coroutines = [process_task(t, skill_plugin, dedup_plugin, client, semaphore, output_file, turn_1_id) for t in tasks] + await asyncio.gather(*coroutines) + + print("\n=== Phase 2 Dataset Smoke Test Complete ===") + print(f"Results saved to {output_file}") + + print("\n=== Combined Cost-Savings Telemetry ===") + metrics = { + "skill_plugin_metrics": skill_plugin.get_plugin_metrics(), + "dedup_plugin_metrics": dedup_plugin.get_plugin_metrics() + } + print(json.dumps(metrics, indent=2)) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/evaluation/benchmarks/run_elm_eval.py b/evaluation/benchmarks/run_elm_eval.py new file mode 100644 index 0000000..eb7530c --- /dev/null +++ b/evaluation/benchmarks/run_elm_eval.py @@ -0,0 +1,150 @@ +import asyncio +import os +import json +import logging +from typing import Any, Dict +from openai import AsyncOpenAI + +# Set PYTHONPATH in the environment before running if needed +from refactored_plugins.skill_index import SkillAwareContextPlugin +from refactored_plugins.dedup import ContextDedupPlugin + +logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s") + +# Dummy tool registry for SkillAwareContextPlugin to filter from +DUMMY_TOOL_REGISTRY = { + "python_repl": { + "type": "function", + "function": { + "name": "python_repl", + "description": "Executes Python code in a sandboxed environment" + } + }, + "web_search": { + "type": "function", + "function": { + "name": "web_search", + "description": "Searches the web for up-to-date documentation" + } + }, + "file_writer": { + "type": "function", + "function": { + "name": "file_writer", + "description": "Writes code to a file" + } + } +} + +async def evaluate_task(task_data: Dict[str, Any], + skill_plugin: SkillAwareContextPlugin, + dedup_plugin: ContextDedupPlugin, + client: AsyncOpenAI) -> Dict[str, Any]: + """ + Optimizes a request using Phase 1 plugins and sends it to the ELM API. + """ + # 1. ContextDedupPlugin (strip redundant history) + optimized_data = await dedup_plugin.process(task_data) + + # 2. SkillAwareContextPlugin (strip redundant tools based on _required_skills) + optimized_data = await skill_plugin.process(optimized_data) + + # 3. Call ELM API via OpenAI client + api_kwargs = { + "model": "gpt-5.5", + "messages": optimized_data.get("messages", []) + } + + # Only pass tools if the plugin injected any + if "tools" in optimized_data and optimized_data["tools"]: + api_kwargs["tools"] = optimized_data["tools"] + + try: + response = await client.chat.completions.create(**api_kwargs) + message = response.choices[0].message + if message.tool_calls: + # If the model decides to invoke tools, format the tool calls details + calls = [] + for tc in message.tool_calls: + # Safely parse arguments if present + args = {} + if tc.function.arguments: + try: + args = json.loads(tc.function.arguments) + except Exception: + args = tc.function.arguments + calls.append({ + "id": tc.id, + "name": tc.function.name, + "arguments": args + }) + response_content = f"Tool Calls Triggered:\n{json.dumps(calls, indent=2)}" + else: + response_content = message.content + except Exception as e: + response_content = f"API Error: {str(e)}" + response = None + + # 4. Gather Telemetry + telemetry = { + "skill_plugin_metrics": skill_plugin.get_plugin_metrics(), + "dedup_plugin_metrics": dedup_plugin.get_plugin_metrics() + } + + return { + "response": response, + "response_content": response_content, + "telemetry": telemetry, + "optimized_payload": optimized_data + } + +async def main(): + api_key = os.environ.get("OPENAI_API_KEY", "dummy-elm-key") + base_url = os.environ.get("BASE_URL", "https://api.openai.com/v1") + + client = AsyncOpenAI(api_key=api_key, base_url=base_url) + + skill_plugin = SkillAwareContextPlugin(tool_registry=DUMMY_TOOL_REGISTRY) + dedup_plugin = ContextDedupPlugin() + + print("=== Phase 2: ELM API Evaluator ===") + + # To demonstrate deduplication savings, we first run a mock Turn 1 to prime the history + turn_1 = { + "user_id": "evaluator_1", + "messages": [ + {"role": "system", "content": "You are an expert Python engineer taking the BigCodeBench evaluation."}, + {"role": "user", "content": "Write a script to compute the fast inverse square root."}, + {"role": "assistant", "content": "Here is the implementation: `def q_rsqrt(number): ...`"} + ] + } + turn_1_res = await dedup_plugin.process(turn_1) + parent_id = turn_1_res.get("current_id") + + # Mock BigCodeBench Turn 2 Task (includes redundant history from Turn 1) + mock_task = { + "user_id": "evaluator_1", + "parent_id": parent_id, + "_required_skills": ["python_repl", "file_writer"], + "messages": [ + {"role": "system", "content": "You are an expert Python engineer taking the BigCodeBench evaluation."}, + {"role": "user", "content": "Write a script to compute the fast inverse square root."}, + {"role": "assistant", "content": "Here is the implementation: `def q_rsqrt(number): ...`"}, + {"role": "user", "content": "Now, execute this code in the python_repl to verify it handles float(0.15625) correctly."} + ] + } + + print(f"\n[1] Starting API Request Evaluation...") + result = await evaluate_task(mock_task, skill_plugin, dedup_plugin, client) + + print("\n[2] Optimized Payload Sent to ELM API:") + print(json.dumps(result["optimized_payload"], indent=2)) + + print("\n[3] ELM API Response:") + print(result["response_content"]) + + print("\n[4] Cost-Savings Telemetry:") + print(json.dumps(result["telemetry"], indent=2)) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/evaluation/benchmarks/run_sandbox_eval.ps1 b/evaluation/benchmarks/run_sandbox_eval.ps1 new file mode 100644 index 0000000..a85236f --- /dev/null +++ b/evaluation/benchmarks/run_sandbox_eval.ps1 @@ -0,0 +1,29 @@ +$source = "D:\AI4Coding\ContextPilot\evaluation\benchmarks\elm_samples.jsonl" +$destDir = "D:\AI4Coding\ContextPilot\evaluation\sandbox\results" +$destination = Join-Path $destDir "elm_samples.jsonl" + +Write-Host "Preparing Sandbox Environment..." + +# Ensure the results directory exists +if (-not (Test-Path -Path $destDir)) { + New-Item -ItemType Directory -Force -Path $destDir | Out-Null +} + +# Copy the samples to the volume-mounted folder +Copy-Item -Path $source -Destination $destination -Force +Write-Host "Successfully copied elm_samples.jsonl to the Docker volume mount ($destDir)." + +Write-Host "`nChecking if bigcodebench-sandbox is running..." +$containerStatus = docker inspect -f '{{.State.Running}}' bigcodebench-sandbox 2>$null +if ($containerStatus -ne "true") { + Write-Host "Container is not running. Starting it now..." + docker start bigcodebench-sandbox +} else { + Write-Host "Container is already running." +} + +Write-Host "`nExecuting BigCodeBench Evaluation securely inside the Docker Sandbox..." +# Run the evaluation command inside the existing running container +docker exec bigcodebench-sandbox bigcodebench.evaluate --samples /app/results/elm_samples.jsonl + +Write-Host "`nEvaluation Complete!" diff --git a/evaluation/benchmarks/run_sandbox_eval.sh b/evaluation/benchmarks/run_sandbox_eval.sh new file mode 100644 index 0000000..49461c6 --- /dev/null +++ b/evaluation/benchmarks/run_sandbox_eval.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# Define paths (Git Bash / WSL compatible Windows paths or native Windows paths work with cp) +SOURCE="D:/AI4Coding/ContextPilot/evaluation/benchmarks/elm_samples.jsonl" +DEST_DIR="D:/AI4Coding/ContextPilot/evaluation/sandbox/results" +DEST="$DEST_DIR/elm_samples.jsonl" + +echo "Preparing Sandbox Environment..." + +# Ensure the results directory exists +mkdir -p "$DEST_DIR" + +# Copy the samples to the volume-mounted folder +cp "$SOURCE" "$DEST" +echo "Successfully copied elm_samples.jsonl to the Docker volume mount ($DEST_DIR)." + +echo -e "\nChecking if bigcodebench-sandbox is running..." +if [ "$(docker inspect -f '{{.State.Running}}' bigcodebench-sandbox 2>/dev/null)" != "true" ]; then + echo "Container is not running. Starting it now..." + docker start bigcodebench-sandbox +else + echo "Container is already running." +fi + +echo -e "\nExecuting BigCodeBench Evaluation securely inside the Docker Sandbox..." +# Run the evaluation command inside the existing running container +docker exec bigcodebench-sandbox bigcodebench.evaluate --samples /app/results/elm_samples.jsonl + +echo -e "\nEvaluation Complete!" diff --git a/evaluation/core_merge/mock_proxy.py b/evaluation/core_merge/mock_proxy.py new file mode 100644 index 0000000..56be93c --- /dev/null +++ b/evaluation/core_merge/mock_proxy.py @@ -0,0 +1,150 @@ +import asyncio +import json +import logging +import sys +from typing import Any, Dict, List + +# Import all 4 plugins from refactored_plugins +from refactored_plugins.skill_index import SkillAwareContextPlugin +from refactored_plugins.dedup import ContextDedupPlugin +from refactored_plugins.reorder import ContextReorderPlugin +from refactored_plugins.kv_lookup import KVCacheLookupPlugin + +# Configure simple logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s") + +class MockProxy: + """ + Simulates the core Token Proxy Middleware pipeline by chaining all Phase 1 plugins. + """ + def __init__(self): + # 1. SkillAwareContextPlugin (dummy tool registry) + dummy_tool_registry = { + "math": { + "type": "function", + "function": { + "name": "math_tool", + "description": "Performs mathematical calculations" + } + }, + "weather": { + "type": "function", + "function": { + "name": "weather_tool", + "description": "Gets the current weather" + } + } + } + self.skill_plugin = SkillAwareContextPlugin(tool_registry=dummy_tool_registry) + + # 2. ContextDedupPlugin + self.dedup_plugin = ContextDedupPlugin() + + # 3. ContextReorderPlugin + # We specify use_gpu=False for the mock test to avoid requiring torch/CUDA. + self.reorder_plugin = ContextReorderPlugin(use_gpu=False) + + # 4. KVCacheLookupPlugin (dummy ZMQ endpoints) + dummy_endpoints = ["tcp://localhost:5557", "tcp://localhost:5558"] + self.kv_lookup_plugin = KVCacheLookupPlugin(endpoints=dummy_endpoints) + + async def process_batch(self, request_batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Executes the exact chain of execution for a batch of requests. + """ + # 1. SkillAwareContextPlugin (processes individual requests) + batch_after_skill = [] + for req in request_batch: + res = await self.skill_plugin.process(req) + batch_after_skill.append(res) + + # 2. ContextDedupPlugin (processes individual requests) + batch_after_dedup = [] + user_to_last_id = {} + for req in batch_after_skill: + user_id = req.get("user_id") + if user_id in user_to_last_id: + req["parent_id"] = user_to_last_id[user_id] + + res = await self.dedup_plugin.process(req) + + if user_id and "current_id" in res: + user_to_last_id[user_id] = res["current_id"] + + batch_after_dedup.append(res) + + # 3. ContextReorderPlugin (processes the entire batch) + batch_after_reorder = await self.reorder_plugin.process(batch_after_dedup) + + # 4. KVCacheLookupPlugin (processes individual requests) + final_batch = [] + for req in batch_after_reorder: + res = await self.kv_lookup_plugin.process(req) + final_batch.append(res) + + return final_batch + + def get_all_metrics(self) -> Dict[str, Dict[str, float]]: + """Aggregates telemetry from all plugins.""" + return { + "skill_plugin": self.skill_plugin.get_plugin_metrics(), + "dedup_plugin": self.dedup_plugin.get_plugin_metrics(), + "reorder_plugin": self.reorder_plugin.get_plugin_metrics(), + "kv_lookup_plugin": self.kv_lookup_plugin.get_plugin_metrics(), + } + +async def main(): + # Create a complex mock batch of 3 OpenAI requests simulating: + # 1. Multi-turn conversation (redundant history) + # 2. Dynamic Tool/Skill filtering + # 3. Overlapping system prompts for Prefix Sharing + mock_batch = [ + { + "user_id": "user_1", + "_required_skills": ["math"], + "messages": [ + {"role": "system", "content": "You are an AI assistant. Answer accurately and be concise."}, + {"role": "user", "content": "Hello! I am preparing for my exams."}, + {"role": "assistant", "content": "Hello! I can help you study. What subject?"}, + {"role": "user", "content": "Calculate 15 * 32 for my math homework."} + ] + }, + { + "user_id": "user_1", + "_required_skills": ["weather"], + "messages": [ + {"role": "system", "content": "You are an AI assistant. Answer accurately and be concise."}, + {"role": "user", "content": "Hello! I am preparing for my exams."}, + {"role": "assistant", "content": "Hello! I can help you study. What subject?"}, + {"role": "user", "content": "Actually, skip studying. What is the weather outside?"} + ] + }, + { + "user_id": "user_2", + "_required_skills": ["math", "weather", "invalid_skill"], + "messages": [ + {"role": "system", "content": "You are an AI assistant. Answer accurately and be concise."}, + {"role": "user", "content": "I need help with math and weather!"} + ] + } + ] + + print("=== Phase 1: Core Merge - Mock Proxy Pipeline ===\n") + print("[1] Initializing Mock Proxy & 4 Plugins...") + proxy = MockProxy() + + print("\n[2] Executing process_batch()...") + optimized_batch = await proxy.process_batch(mock_batch) + + print("\n=== Optimized Batch Output ===") + print(json.dumps(optimized_batch, indent=2)) + + print("\n=== Telemetry Metrics ===") + metrics = proxy.get_all_metrics() + print(json.dumps(metrics, indent=2)) + + # Force exit to cleanly terminate lingering background ZMQ tasks created by KVCacheLookupPlugin + sys.exit(0) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/evaluation/profiling/dummy_agent.py b/evaluation/profiling/dummy_agent.py new file mode 100644 index 0000000..07281b6 --- /dev/null +++ b/evaluation/profiling/dummy_agent.py @@ -0,0 +1,55 @@ +import json +import re +import time +import random + +def heavy_json_ops(): + """Simulates API payload handling (serialization/deserialization).""" + # Create a reasonably complex nested structure + payload = { + "metadata": {"version": "1.0", "timestamp": time.time()}, + "agents": [ + {"id": i, "name": f"Agent_{i}", "history": ["observation" * 10 for _ in range(20)]} + for i in range(50) + ], + "configuration": {f"key_{i}": "value" * 50 for i in range(100)} + } + + # Burn CPU parsing and serializing + for _ in range(1000): # Reduced from 10k to keep dummy run time reasonable (approx few seconds) + s = json.dumps(payload) + _ = json.loads(s) + +def heavy_regex_ops(): + """Simulates prompt formatting and log parsing.""" + # Large block of text + base_text = "The quick brown fox jumps over the lazy dog. " * 500 + + # Search for patterns and manipulate strings + patterns = [r"\b\w{5}\b", r"fox.*?dog", r"[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+"] + + for _ in range(50): + # Regex search + for p in patterns: + re.findall(p, base_text) + + # String concatenation (O(n^2) behavior in some languages, but Python is optimized) + # Still burns time for large N + _ = "".join([base_text[i:i+10] for i in range(0, len(base_text), 2)]) + +def agent_turn(turn_id): + print(f"Executing Agent Turn {turn_id}...") + heavy_json_ops() + heavy_regex_ops() + +def main(): + start_time = time.time() + # Simulate 50 agent turns + for i in range(50): + agent_turn(i) + + end_time = time.time() + print(f"\nSimulation complete in {end_time - start_time:.2f} seconds.") + +if __name__ == "__main__": + main() diff --git a/evaluation/profiling/run_profile.ps1 b/evaluation/profiling/run_profile.ps1 new file mode 100644 index 0000000..cd6fec9 --- /dev/null +++ b/evaluation/profiling/run_profile.ps1 @@ -0,0 +1,11 @@ +# Ensure dependencies are installed +Write-Host "Installing snakeviz..." -ForegroundColor Cyan +pip install snakeviz --quiet + +# Run the script with cProfile +Write-Host "Running dummy_agent.py with cProfile..." -ForegroundColor Cyan +python -m cProfile -o agent_profile.prof dummy_agent.py + +# Visualize the results +Write-Host "Launching snakeviz for visualization..." -ForegroundColor Cyan +snakeviz agent_profile.prof diff --git a/evaluation/profiling/run_profile.sh b/evaluation/profiling/run_profile.sh new file mode 100644 index 0000000..bb10e44 --- /dev/null +++ b/evaluation/profiling/run_profile.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +# Ensure dependencies are installed +echo "Installing snakeviz..." +pip install snakeviz --quiet + +# Run the script with cProfile +# -o agent_profile.prof: outputs the binary profile data to a file +echo "Running dummy_agent.py with cProfile..." +python -m cProfile -o agent_profile.prof dummy_agent.py + +# Visualize the results +# This will open a browser tab with a flame graph/icicle graph +echo "Launching snakeviz for visualization..." +snakeviz agent_profile.prof diff --git a/evaluation/sandbox/Dockerfile b/evaluation/sandbox/Dockerfile new file mode 100644 index 0000000..e8ab197 --- /dev/null +++ b/evaluation/sandbox/Dockerfile @@ -0,0 +1,25 @@ +# Use a slim Python 3.10 base image +FROM python:3.10-slim + +# Install system dependencies that might be needed for common code execution +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +# Create a non-root user for security +RUN useradd -m -s /bin/bash evaluser + +# Set working directory +WORKDIR /app + +# Install bigcodebench +RUN pip install --no-cache-dir bigcodebench transformers torch + +# Create results directory and set ownership +RUN mkdir -p /app/results && chown -R evaluser:evaluser /app + +# Switch to the non-root user +USER evaluser + +# Set entrypoint to bash for interactive testing/persistence +ENTRYPOINT ["/bin/bash"] diff --git a/evaluation/sandbox/docker-compose.yml b/evaluation/sandbox/docker-compose.yml new file mode 100644 index 0000000..fef038b --- /dev/null +++ b/evaluation/sandbox/docker-compose.yml @@ -0,0 +1,21 @@ +version: '3.8' + +services: + bigcodebench-sandbox: + build: + context: . + dockerfile: Dockerfile + container_name: bigcodebench-sandbox + # Security: Disable network access to prevent data exfiltration or external attacks + network_mode: "none" + deploy: + resources: + limits: + cpus: '1.0' + memory: 2G + volumes: + # Mount results directory to extract benchmark data + - ./results:/app/results + working_dir: /app + stdin_open: true # docker run -i + tty: true # docker run -t diff --git a/evaluation/sandbox/results/.gitkeep b/evaluation/sandbox/results/.gitkeep new file mode 100644 index 0000000..98770b1 --- /dev/null +++ b/evaluation/sandbox/results/.gitkeep @@ -0,0 +1 @@ +# Keep this directory tracked by Git, so the Docker mount target exists. diff --git a/evaluation/zmq_prototype/dummy_sglang_pub.py b/evaluation/zmq_prototype/dummy_sglang_pub.py new file mode 100644 index 0000000..ad7ba1e --- /dev/null +++ b/evaluation/zmq_prototype/dummy_sglang_pub.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +""" +dummy_sglang_pub.py + +A ZeroMQ Publisher that simulates SGLang's internal Radix Tree KV Cache event stream. +It binds to tcp://*:5557 and randomly publishes BlockStored and BlockRemoved events +in a logically consistent manner (maintaining a simulated tree structure) every second. + +Author: Senior AI Infrastructure Software Engineer +Project: Middleware Token Proxy Middleware - WP2 ZMQ Prototype +""" + +import zmq +import json +import time +import random +import sys + +# ANSI Escape Sequences for beautiful terminal output +class Colors: + HEADER = '\033[95m' + BLUE = '\033[94m' + GREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + CYAN = '\033[96m' + +def log_info(msg): + print(f"{Colors.BLUE}[INFO]{Colors.ENDC} {msg}") + +def log_success(msg): + print(f"{Colors.GREEN}[STORED]{Colors.ENDC} {msg}") + +def log_warning(msg): + print(f"{Colors.WARNING}[REMOVED]{Colors.ENDC} {msg}") + +def main(): + # Setup ZMQ Context and PUB Socket + context = zmq.Context() + publisher = context.socket(zmq.PUB) + + # Configure high-water mark to prevent memory bloat + publisher.set_hwm(1000) + + bind_address = "tcp://*:5557" + try: + publisher.bind(bind_address) + except Exception as e: + print(f"{Colors.FAIL}{Colors.BOLD}Failed to bind to {bind_address}: {e}{Colors.ENDC}", file=sys.stderr) + sys.exit(1) + + print(f"{Colors.HEADER}{Colors.BOLD}" + "="*60 + f"{Colors.ENDC}") + print(f"{Colors.HEADER}{Colors.BOLD}SGLang Dummy KV Cache Event Publisher Started{Colors.ENDC}") + print(f"{Colors.CYAN}Binding Address:{Colors.ENDC} {bind_address}") + print(f"{Colors.CYAN}Simulation Rate:{Colors.ENDC} 1 event / sec") + print(f"{Colors.HEADER}{Colors.BOLD}" + "="*60 + f"{Colors.ENDC}") + print("Press Ctrl+C to terminate the publisher gracefully.\n") + + # Local state to ensure logically consistent events + # Maps block_hash -> parent_block_hash + active_blocks = {} + + # Track sequence number to simulate SGLang's monotonic event sequence + sequence_number = 0 + + try: + while True: + # We want to maintain a reasonable number of blocks in the cache (e.g., 3 to 15) + num_blocks = len(active_blocks) + + # Decide whether to store a new block or remove an existing one + # If cache is empty, we must store. If cache is full (> 15), we prefer to remove. + if num_blocks == 0: + action = "store" + elif num_blocks > 12: + action = "remove" if random.random() < 0.7 else "store" + else: + # 65% chance to store (grow), 35% chance to remove (evict) + action = "store" if random.random() < 0.65 else "remove" + + sequence_number += 1 + + if action == "store": + # Generate a unique block hash (simulating a 64-bit integer hash) + block_hash = random.randint(10000000, 99999999) + while block_hash in active_blocks: + block_hash = random.randint(10000000, 99999999) + + # Determine parent block hash (simulate tree branching) + parent_block_hash = None + if active_blocks and random.random() < 0.7: + # Pick an existing block as parent to create a hierarchy + parent_block_hash = random.choice(list(active_blocks.keys())) + + # Generate a list of random token IDs (representing the prefix contents of this block) + # SGLang usually has a block size / page size (e.g. 16 tokens) + block_size = 16 + token_ids = [random.randint(1, 50000) for _ in range(block_size)] + + # Construct BlockStored event payload + event = { + "type": "BlockStored", + "sequence": sequence_number, + "block_hash": block_hash, + "parent_block_hash": parent_block_hash, + "token_ids": token_ids, + "block_size": block_size, + "medium": "GPU" + } + + # Update local tracking + active_blocks[block_hash] = parent_block_hash + + # Publish event + event_str = json.dumps(event) + publisher.send_string(event_str) + + parent_str = f"0x{parent_block_hash:08x}" if parent_block_hash else "None" + log_success(f"Block: 0x{block_hash:08x} | Parent: {parent_str} | Tokens: {token_ids[:3]}... ({block_size} tokens)") + + else: # remove + # Pick a block to remove + # To simulate realistic tree eviction, we should ideally evict leaf blocks. + # Let's find leaf blocks (blocks that are not parents of any other active blocks) + all_parents = set(active_blocks.values()) + leaves = [b for b in active_blocks if b not in all_parents] + + # Fallback to any active block if no leaves are easily found (should always find at least one) + block_to_remove = random.choice(leaves) if leaves else random.choice(list(active_blocks.keys())) + + # Construct BlockRemoved event payload + event = { + "type": "BlockRemoved", + "sequence": sequence_number, + "block_hash": block_to_remove, + "medium": "GPU" + } + + # Update local tracking + del active_blocks[block_to_remove] + + # Publish event + event_str = json.dumps(event) + publisher.send_string(event_str) + + log_warning(f"Block: 0x{block_to_remove:08x}") + + time.sleep(1.0) + + except KeyboardInterrupt: + print(f"\n{Colors.WARNING}Shutting down publisher...{Colors.ENDC}") + finally: + publisher.close() + context.term() + print(f"{Colors.GREEN}Publisher terminated cleanly.{Colors.ENDC}") + +if __name__ == "__main__": + main() diff --git a/evaluation/zmq_prototype/shadow_tree_sub.py b/evaluation/zmq_prototype/shadow_tree_sub.py new file mode 100644 index 0000000..e16c514 --- /dev/null +++ b/evaluation/zmq_prototype/shadow_tree_sub.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +""" +shadow_tree_sub.py + +A ZeroMQ Subscriber that connects to tcp://localhost:5557, subscribes to all events, +and maintains a local "shadow_cache" dictionary reflecting SGLang's Radix Tree cache. +When events are received, it updates the shadow cache and prints a beautiful ASCII +visualization of the live tree hierarchy. + +Author: Senior AI Infrastructure Software Engineer +Project: Middleware Token Proxy Middleware - WP2 ZMQ Prototype +""" + +import zmq +import json +import sys +from collections import defaultdict + +# ANSI Escape Sequences for beautiful terminal output +class Colors: + HEADER = '\033[95m' + BLUE = '\033[94m' + GREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + CYAN = '\033[96m' + DARK_GRAY = '\033[90m' + +def print_radix_tree(shadow_cache): + """ + Reconstructs and prints the Radix Tree hierarchy from the flat shadow_cache. + """ + if not shadow_cache: + print(f"{Colors.DARK_GRAY} [Cache is currently empty]{Colors.ENDC}") + return + + # Build adjacency list: parent -> list of child block hashes + children = defaultdict(list) + roots = [] + + for block_hash, block_data in shadow_cache.items(): + parent = block_data.get("parent_block_hash") + # A block is a root if its parent is None OR if its parent is not in the shadow cache + if parent is None or parent not in shadow_cache: + roots.append(block_hash) + else: + children[parent].append(block_hash) + + # Sort roots and children by block hash for deterministic display + roots.sort() + for parent in children: + children[parent].sort() + + total_tokens = sum(len(b.get("token_ids", [])) for b in shadow_cache.values()) + print(f"\n{Colors.BOLD}Shadow Radix Tree Cache State:{Colors.ENDC}") + print(f" ├─ {Colors.CYAN}Total Blocks:{Colors.ENDC} {len(shadow_cache)}") + print(f" └─ {Colors.CYAN}Total Tokens:{Colors.ENDC} {total_tokens}") + print(f"{Colors.DARK_GRAY}Tree Visualization:{Colors.ENDC}") + + def dfs(node_hash, prefix="", is_last=True): + node_data = shadow_cache[node_hash] + token_ids = node_data.get("token_ids", []) + + # Display shortened preview of token IDs + if len(token_ids) > 6: + token_preview = f"[{', '.join(map(str, token_ids[:3]))}, ..., {', '.join(map(str, token_ids[-3:]))}]" + else: + token_preview = str(token_ids) + + connector = "└── " if is_last else "├── " + node_label = f"{Colors.GREEN}Block 0x{node_hash:08x}{Colors.ENDC}" + details = f"{Colors.DARK_GRAY}(tokens: {len(token_ids)}, val: {token_preview}){Colors.ENDC}" + + print(f"{prefix}{connector}{node_label} {details}") + + # Recurse children + node_children = children[node_hash] + child_count = len(node_children) + for i, child_hash in enumerate(node_children): + new_prefix = prefix + (" " if is_last else "│ ") + dfs(child_hash, new_prefix, is_last=(i == child_count - 1)) + + # Print the tree starting from each root node + for idx, root_hash in enumerate(roots): + dfs(root_hash, prefix=" ", is_last=(idx == len(roots) - 1)) + print() + +def main(): + # Setup ZMQ Context and SUB Socket + context = zmq.Context() + subscriber = context.socket(zmq.SUB) + + connect_address = "tcp://localhost:5557" + try: + subscriber.connect(connect_address) + except Exception as e: + print(f"{Colors.FAIL}{Colors.BOLD}Failed to connect to {connect_address}: {e}{Colors.ENDC}", file=sys.stderr) + sys.exit(1) + + # Subscribe to all events (empty prefix string) + subscriber.setsockopt_string(zmq.SUBSCRIBE, "") + + print(f"{Colors.HEADER}{Colors.BOLD}" + "="*60 + f"{Colors.ENDC}") + print(f"{Colors.HEADER}{Colors.BOLD}SGLang Shadow KV Cache Tree Subscriber Started{Colors.ENDC}") + print(f"{Colors.CYAN}Connecting to:{Colors.ENDC} {connect_address}") + print(f"{Colors.CYAN}Subscription topic:{Colors.ENDC} [ALL EVENTS]") + print(f"{Colors.HEADER}{Colors.BOLD}" + "="*60 + f"{Colors.ENDC}") + print("Waiting for SGLang publisher events... (Press Ctrl+C to terminate)\n") + + # In-memory dictionary tracking cache state: block_hash -> metadata dict + shadow_cache = {} + + try: + while True: + # Receive event string + event_str = subscriber.recv_string() + + try: + event = json.loads(event_str) + except json.JSONDecodeError as je: + print(f"{Colors.FAIL}[ERROR] Failed to parse JSON event: {je}{Colors.ENDC}", file=sys.stderr) + continue + + event_type = event.get("type") + seq = event.get("sequence", 0) + block_hash = event.get("block_hash") + + if not event_type or block_hash is None: + print(f"{Colors.WARNING}[WARN] Received invalid event structure: {event}{Colors.ENDC}") + continue + + print(f"{Colors.BLUE}[Seq: {seq:03d}]{Colors.ENDC} Received {Colors.BOLD}{event_type}{Colors.ENDC} for Block {Colors.BOLD}0x{block_hash:08x}{Colors.ENDC}") + + if event_type == "BlockStored": + # Store the block details in the local shadow tree cache + shadow_cache[block_hash] = { + "parent_block_hash": event.get("parent_block_hash"), + "token_ids": event.get("token_ids", []), + "block_size": event.get("block_size", 0), + "medium": event.get("medium", "GPU"), + "seq": seq + } + elif event_type == "BlockRemoved": + # Remove the block from the local shadow tree cache + if block_hash in shadow_cache: + del shadow_cache[block_hash] + else: + print(f" {Colors.WARNING}* Block 0x{block_hash:08x} was not found in local shadow cache, skipping deletion *{Colors.ENDC}") + else: + print(f" {Colors.WARNING}* Unknown event type: {event_type} *{Colors.ENDC}") + + # Reconstruct and display the current cache tree structure + print_radix_tree(shadow_cache) + print("-" * 60) + + except KeyboardInterrupt: + print(f"\n{Colors.WARNING}Shutting down subscriber...{Colors.ENDC}") + finally: + subscriber.close() + context.term() + print(f"{Colors.GREEN}Subscriber terminated cleanly.{Colors.ENDC}") + +if __name__ == "__main__": + main() diff --git a/examples/http_server_example.py b/examples/http_server_example.py index 6959429..c5c2fee 100644 --- a/examples/http_server_example.py +++ b/examples/http_server_example.py @@ -32,7 +32,7 @@ def check_server(): try: response = requests.get(f"{BASE_URL}/health", timeout=2.0) health = response.json() - print(f"✓ Server status: {health['status']}") + print(f"+ Server status: {health['status']}") return health except Exception as e: print(f"✗ Server not running: {e}") @@ -81,7 +81,7 @@ def build_index(): ) result = response.json() - print(f"✓ Index built: {len(result['request_ids'])} request IDs") + print(f"+ Index built: {len(result['request_ids'])} request IDs") print(f" Reordered contexts for optimal cache sharing") return result @@ -140,7 +140,7 @@ def stateless_schedule(): ) result = response.json() - print(f"✓ Reordered into {len(result['groups'])} groups") + print(f"+ Reordered into {len(result['groups'])} groups") return result @@ -219,7 +219,7 @@ def main(): print() print("=" * 70) - print("✓ Example complete!") + print("+ Example complete!") print("=" * 70) diff --git a/examples/stateless_batch_example.py b/examples/stateless_batch_example.py index 3d62f07..2cf8652 100644 --- a/examples/stateless_batch_example.py +++ b/examples/stateless_batch_example.py @@ -44,7 +44,7 @@ def example_with_client(): result = client.reorder_raw(contexts) if result: - print(f"\n✓ Batch reordered successfully!") + print(f"\n+ Batch reordered successfully!") print(f" Mode: {result.get('mode', 'stateless')}") print(f" Number of contexts: {result['num_contexts']}") print(f" Number of execution groups: {result['num_groups']}") @@ -87,7 +87,7 @@ def example_with_function(): ) if result: - print(f"✓ Reordered {result['num_contexts']} contexts into {result['num_groups']} groups") + print(f"+ Reordered {result['num_contexts']} contexts into {result['num_groups']} groups") print(f"Original indices order: {result['original_indices']}") else: print("Failed to reorder batch") @@ -122,7 +122,7 @@ def example_direct_http(): if response.status_code == 200: result = response.json() - print(f"✓ Reordered successfully!") + print(f"+ Reordered successfully!") print(f" Groups: {result['num_groups']}") print(f" Order: {result['original_indices']}") else: @@ -173,8 +173,8 @@ def batch_processing_workflow(): return execution_order = result['original_indices'] - print(f" ✓ Optimal order: {execution_order}") - print(f" ✓ {result['num_groups']} execution groups") + print(f" + Optimal order: {execution_order}") + print(f" + {result['num_groups']} execution groups") # Step 3: Reorder your data according to the execution order print("\n3. Reordering data for inference...") @@ -219,7 +219,7 @@ def batch_processing_workflow(): response = requests.get("http://localhost:8765/health", timeout=2) if response.status_code == 200: health = response.json() - print(f"\n✓ Server is running (mode: {health.get('mode', 'unknown')})") + print(f"\n+ Server is running (mode: {health.get('mode', 'unknown')})") else: print(f"\n✗ Server returned status {response.status_code}") sys.exit(1) diff --git a/examples/stateless_sglang_e2e.py b/examples/stateless_sglang_e2e.py index b59beb2..f4d2eb5 100644 --- a/examples/stateless_sglang_e2e.py +++ b/examples/stateless_sglang_e2e.py @@ -177,9 +177,9 @@ def run_rag_with_contextpilot( # 2. IDs within each context reordered (shared IDs as prefix) reordered_contexts = schedule_result['reordered_contexts'] num_groups = schedule_result['num_groups'] - print(f" ✓ Optimal order: {scheduled_order}") - print(f" ✓ Grouped into {num_groups} execution groups") - print(f" ✓ Document IDs reordered within each context for prefix sharing") + print(f" + Optimal order: {scheduled_order}") + print(f" + Grouped into {num_groups} execution groups") + print(f" + Document IDs reordered within each context for prefix sharing") else: print(" ⚠ ContextPilot unavailable, using original order") scheduled_order = list(range(n)) @@ -214,7 +214,7 @@ def run_rag_with_contextpilot( print(f" Generating response {i+1}/{len(prompts)}...", end=" ") response = llm_generate(prompt) responses.append(response) - print("✓") + print("+") # Option B: Batch (uncomment to use) # responses = llm_generate_batch(prompts) @@ -236,7 +236,7 @@ def run_rag_with_contextpilot( 'scheduled_position': scheduled_pos, }) - print(" ✓ Results reordered to match original query order") + print(" + Results reordered to match original query order") return results @@ -255,7 +255,7 @@ def main(): try: r = requests.get(f"{CONTEXTPILOT_URL}/health", timeout=2) - print(f" ContextPilot: ✓ ({r.json().get('mode', 'unknown')} mode)") + print(f" ContextPilot: + ({r.json().get('mode', 'unknown')} mode)") contextpilot_available = True except: print(f" ContextPilot: ✗ Not available at {CONTEXTPILOT_URL}") @@ -263,7 +263,7 @@ def main(): try: r = requests.get(f"{INFERENCE_URL}/health", timeout=2) - print(f" Inference engine: ✓ Ready") + print(f" Inference engine: + Ready") engine_available = True except: print(f" Inference engine: ✗ Not available at {INFERENCE_URL}") diff --git a/pyproject.toml b/pyproject.toml index 9757c7f..21bd07a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ contextpilot-llama-server = "contextpilot._llamacpp_hook:main" [project.optional-dependencies] gpu = ["cupy-cuda12x"] -dev = ["black", "bumpver", "isort", "pip-tools", "pytest", "pytest-cov", "ipython"] +dev = ["black", "bumpver", "isort", "pip-tools", "pytest", "pytest-cov", "ipython", "pytest-asyncio"] sglang = ["sglang>=0.5"] [tool.pytest.ini_options] diff --git a/refactored_plugins/INTEGRATION_GUIDE.md b/refactored_plugins/INTEGRATION_GUIDE.md new file mode 100644 index 0000000..3f2a76a --- /dev/null +++ b/refactored_plugins/INTEGRATION_GUIDE.md @@ -0,0 +1,90 @@ +# Token Proxy Integration Guide (Reuse & Routing Plugins) + +This guide explains how to integrate the **ContextReorder**, **ContextDedup**, and **KVCacheLookup** plugins into the Token Proxy Middleware. + +## Overview + +The plugins are designed to sit between the Agent Framework (OpenClaw) and the LLM Server (SGLang). They process standard OpenAI-formatted JSON payloads to maximize KV cache hits and dynamically route requests to the workers with the highest prefix sharing. + +## Installation + +Ensure the `contextpilot` core logic is installed in your environment: +```bash +pip install -e . +``` + +## Sequential Usage Example + +Here is how to use both plugins in a sequential pipeline (e.g., inside a FastAPI endpoint). + +```python +import asyncio +from refactored_plugins.reorder import ContextReorderPlugin +from refactored_plugins.dedup import ContextDedupPlugin + +# 1. Initialize Plugins +# Tip: Use the same model name as your SGLang backend for accurate tokenization +reorder_plugin = ContextReorderPlugin(model_name="Qwen/Qwen2.5-7B-Instruct") +dedup_plugin = ContextDedupPlugin() + +async def proxy_middleware_endpoint(request_batch: list): + """ + Example middleware logic for a batch of requests. + """ + + # --- STAGE 1: DEDUPLICATION --- + # Process each request in the batch to remove redundant history + compressed_batch = [] + for req in request_batch: + optimized_req = await dedup_plugin.process(req) + compressed_batch.append(optimized_req) + + # --- STAGE 2: REORDERING --- + # Reorder the whole batch to maximize prefix sharing in the KV Cache + final_batch = await reorder_plugin.process(compressed_batch) + + # --- TELEMETRY --- + print(f"Reorder Metrics: {reorder_plugin.get_plugin_metrics()}") + print(f"Dedup Metrics: {dedup_plugin.get_plugin_metrics()}") + + return final_batch + +# Mock Batch for Testing +mock_request = { + "user_id": "user_1", + "messages": [ + {"role": "system", "content": "Tool: Math. Tool: Code."}, + {"role": "user", "content": "Calculate 2+2"} + ] +} + +# Run the pipeline +if __name__ == "__main__": + final = asyncio.run(proxy_middleware_endpoint([mock_request])) + print("Optimization Complete.") +``` + +## Plugin Specifics + +### ContextReorderPlugin +- **Input**: `List[Dict]` (Batch of OpenAI requests). +- **Output**: `List[Dict]` (Reordered batch). +- **Logic**: Clusters requests by content overlap and schedules the execution sequence to ensure adjacent requests share the longest possible prefix. + +### ContextDedupPlugin +- **Input**: `Dict` (Single OpenAI request). +- **Output**: `Dict` (Modified request with history replaced by reference hints). +- **Logic**: Tracks conversation state via `user_id` and `parent_id`. Replaces previously seen messages with strings like `[Reference to Turn 1]`. + +### KVCacheLookupPlugin +- **Input**: `Dict` (Single OpenAI request). +- **Output**: `Dict` (Modified request with `_route_to` key injected containing the target worker endpoint). +- **Logic**: Subscribes to individual worker ZMQ event streams (listening for `BlockStored` and `BlockRemoved` events) to build and maintain a real-time, in-memory **Shadow Radix Tree** representing each worker's GPU KV cache. Upon receiving a request, it tokenizes the prompt, runs a longest-prefix-match search across all worker shadow trees, and routes the request to the worker with the most cached tokens. +- **Dependencies**: Requires `pyzmq` and `msgspec` (installed via core dependencies). + +### SkillAwareContextPlugin +- **Input**: `Dict` (Single OpenAI request). +- **Output**: `Dict` (Modified request with a filtered `tools` array). +- **Logic**: Inspects the request for `_required_skills` (a list of skill names requested by the router). It maps these requested skills to function/tool schemas in a pre-registered `tool_registry` and injects only the relevant tools into the payload's `tools` array, dynamically trimming unused tools to minimize context size. +- **Telemetry**: Tracks the total number of tools filtered out and the last execution duration in milliseconds. + diff --git a/refactored_plugins/__init__.py b/refactored_plugins/__init__.py new file mode 100644 index 0000000..92c7b0b --- /dev/null +++ b/refactored_plugins/__init__.py @@ -0,0 +1 @@ +# Token Proxy Plugins package diff --git a/refactored_plugins/base.py b/refactored_plugins/base.py new file mode 100644 index 0000000..2755def --- /dev/null +++ b/refactored_plugins/base.py @@ -0,0 +1,48 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + + +class BasePlugin(ABC): + """ + Abstract base class for all Token Proxy Plugins. + Each plugin intercepts a request and performs a specific optimization. + """ + + def __init__(self, name: str): + self.name = name + + @abstractmethod + async def process(self, request_data: Any) -> Any: + """ + Process the incoming request data and return the optimized version. + """ + pass + + @abstractmethod + def get_plugin_metrics(self) -> Dict[str, float]: + """ + Return a dictionary of performance and optimization metrics. + """ + pass + + +class ContextReorderPlugin(BasePlugin): + """ + Plugin for reordering prompts to maximize KV Cache prefix sharing. + Uses ContextPilot clustering and scheduling logic. + """ + + def __init__(self, alpha: float = 0.001, use_gpu: bool = False): + super().__init__("context_reorder") + from contextpilot.server.live_index import ContextPilot + + self.pilot = ContextPilot(alpha=alpha, use_gpu=use_gpu) + + async def process(self, request_batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Specialized process for batches. + Note: The Proxy framework will need to handle batching. + """ + # Extract prompts/messages from batch + # This is where we'll bridge the OpenAI format to ContextPilot's token lists + pass diff --git a/refactored_plugins/dedup.py b/refactored_plugins/dedup.py new file mode 100644 index 0000000..104a81b --- /dev/null +++ b/refactored_plugins/dedup.py @@ -0,0 +1,92 @@ +import logging +import time +import uuid +from typing import Any, Dict, List, Optional +from .base import BasePlugin + +logger = logging.getLogger(__name__) + + +class ContextDedupPlugin(BasePlugin): + """ + Plugin for deduplicating redundant conversational history in multi-turn requests. + Uses ContextPilot's ConversationTracker to replace repeated messages with reference hints. + """ + + def __init__(self, hint_template: str = "[Reference to Turn {turn_number}]"): + super().__init__("context_dedup") + from contextpilot.server.conversation_tracker import ConversationTracker + + self.tracker = ConversationTracker(hint_template=hint_template) + self._content_to_id = {} + self._next_id = 0 + + # Telemetry + self.total_chars_saved = 0 + self.total_requests_processed = 0 + self.last_execution_time_ms = 0.0 + + def _get_id(self, content: str) -> int: + """Map message content to a unique integer ID.""" + if content not in self._content_to_id: + self._content_to_id[content] = self._next_id + self._next_id += 1 + return self._content_to_id[content] + + async def process(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """ + Deduplicate a single OpenAI request. + """ + messages = request_data.get("messages", []) + if not messages: + return request_data + + start_time = time.perf_counter() + + # Calculate original char length for telemetry + original_len = sum(len(m.get("content", "")) for m in messages) + + conv_id = request_data.get("user_id", "default_session") + parent_id = request_data.get("parent_id") + + # 1. Convert messages to IDs + message_ids = [self._get_id(m.get("content", "")) for m in messages] + + # 2. Run Deduplication + current_req_id = str(uuid.uuid4()) + result = self.tracker.deduplicate(request_id=current_req_id, docs=message_ids, parent_request_id=parent_id) + + # 3. Reconstruct messages with hints + new_messages = [] + for i, m in enumerate(messages): + msg_id = message_ids[i] + if msg_id in result.overlapping_docs: + hint_idx = result.overlapping_docs.index(msg_id) + hint_text = result.reference_hints[hint_idx] + new_messages.append({"role": m.get("role"), "content": hint_text}) + else: + new_messages.append(m) + + # Update Request + optimized_request = dict(request_data) + optimized_request["messages"] = new_messages + optimized_request["current_id"] = current_req_id + + # Update Telemetry + dedup_len = sum(len(m.get("content", "")) for m in new_messages) + self.total_chars_saved += original_len - dedup_len + self.total_requests_processed += 1 + self.last_execution_time_ms = (time.perf_counter() - start_time) * 1000 + + logger.info( + f"Deduplicated request in {self.last_execution_time_ms:.2f}ms. Saved {original_len - dedup_len} chars." + ) + return optimized_request + + def get_plugin_metrics(self) -> Dict[str, float]: + """Return deduplication metrics.""" + return { + "total_chars_saved": float(self.total_chars_saved), + "total_requests_processed": float(self.total_requests_processed), + "last_execution_time_ms": self.last_execution_time_ms, + } diff --git a/refactored_plugins/kv_lookup.py b/refactored_plugins/kv_lookup.py new file mode 100644 index 0000000..e0e65d9 --- /dev/null +++ b/refactored_plugins/kv_lookup.py @@ -0,0 +1,165 @@ +import asyncio +import logging +import zmq +import zmq.asyncio +import msgspec +from typing import Any, Dict, List, Optional +from .base import BasePlugin + +logger = logging.getLogger(__name__) + + +class ShadowRadixTree: + """ + Maintains a shadow copy of the worker's KV cache Radix Tree state + by mapping block hashes to their parents and token contents. + """ + + def __init__(self): + # block_hash -> {"parent": parent_hash, "tokens": token_ids} + self.state: Dict[int, Dict[str, Any]] = {} + # Cache to speed up prefix matching queries + self._full_tokens_cache: Dict[int, List[int]] = {} + + def add_block(self, block_hash: int, parent_hash: Optional[int], token_ids: List[int]): + self.state[block_hash] = {"parent": parent_hash, "tokens": token_ids} + self._full_tokens_cache.clear() + + def remove_block(self, block_hash: int): + if block_hash in self.state: + del self.state[block_hash] + self._full_tokens_cache.clear() + + def _get_tokens(self, block_hash: int) -> List[int]: + if block_hash in self._full_tokens_cache: + return self._full_tokens_cache[block_hash] + + path = [] + curr = block_hash + while curr is not None and curr in self.state: + path.append(curr) + curr = self.state[curr]["parent"] + + tokens = [] + for b_hash in reversed(path): + tokens.extend(self.state[b_hash]["tokens"]) + + self._full_tokens_cache[block_hash] = tokens + return tokens + + def longest_prefix_match(self, target_token_ids: List[int]) -> int: + """ + Traverses from root blocks down to find how many tokens match the target. + Returns the maximum matched tokens. + """ + best_match = 0 + for block_hash in self.state: + node_tokens = self._get_tokens(block_hash) + match_len = 0 + for t1, t2 in zip(target_token_ids, node_tokens): + if t1 == t2: + match_len += 1 + else: + break + if match_len > best_match: + best_match = match_len + return best_match + + +class KVCacheLookupPlugin(BasePlugin): + """ + Plugin for routing requests to the worker with the highest KV cache prefix match. + Subscribes to worker ZMQ streams to maintain shadow Radix trees. + """ + + def __init__(self, endpoints: List[str], model_name: str = "Qwen/Qwen2.5-7B-Instruct"): + super().__init__("kv_cache_lookup") + from contextpilot.utils.prompt_generator import get_tokenizer + + self.tokenizer = get_tokenizer(model_name) + if self.tokenizer is None: + logger.warning(f"Could not load tokenizer for {model_name}. Using fallback char-split.") + + self.endpoints = endpoints + self.trees: Dict[str, ShadowRadixTree] = {endpoint: ShadowRadixTree() for endpoint in endpoints} + + self.ctx = zmq.asyncio.Context() + self.listener_tasks = [] + + # Spawn ZMQ listener tasks for each endpoint + for endpoint in endpoints: + task = asyncio.create_task(self._listen(endpoint)) + self.listener_tasks.append(task) + + async def _listen(self, endpoint: str): + sub = self.ctx.socket(zmq.SUB) + sub.connect(endpoint) + sub.setsockopt_string(zmq.SUBSCRIBE, "") + + tree = self.trees[endpoint] + + while True: + try: + parts = await sub.recv_multipart() + # SGLang format: topic, seq, msgpack_payload + if len(parts) >= 3: + payload = parts[2] + elif len(parts) == 1: + payload = parts[0] + else: + continue + + event = msgspec.msgpack.decode(payload) + event_type = event.get("type") or event.get("event_type") + + if event_type == "BlockStored": + block_hash = event.get("block_hash") + parent_hash = event.get("parent_block_hash") + token_ids = event.get("token_ids", []) + if block_hash is not None: + tree.add_block(block_hash, parent_hash, token_ids) + + elif event_type == "BlockRemoved": + block_hash = event.get("block_hash") + if block_hash is not None: + tree.remove_block(block_hash) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error in ZMQ listener for {endpoint}: {e}") + + def _tokenize(self, text: str) -> List[int]: + if self.tokenizer: + return self.tokenizer.encode(text, add_special_tokens=False) + return [ord(c) for c in text] + + async def process(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """ + Tokenize the request messages and query all worker Shadow Radix Trees + for the longest prefix match. Inject '_route_to' into the request. + """ + messages = request_data.get("messages", []) + if not messages: + return request_data + + full_text = "\n".join([m.get("content", "") for m in messages]) + target_tokens = self._tokenize(full_text) + + best_endpoint = None + max_match = -1 + + for endpoint, tree in self.trees.items(): + match_len = tree.longest_prefix_match(target_tokens) + if match_len > max_match: + max_match = match_len + best_endpoint = endpoint + + optimized_request = dict(request_data) + if best_endpoint: + optimized_request["_route_to"] = best_endpoint + + return optimized_request + + def get_plugin_metrics(self) -> Dict[str, float]: + return {} diff --git a/refactored_plugins/reorder.py b/refactored_plugins/reorder.py new file mode 100644 index 0000000..1d1ccd1 --- /dev/null +++ b/refactored_plugins/reorder.py @@ -0,0 +1,70 @@ +import logging +import time +from typing import Any, Dict, List, Optional +from .base import BasePlugin + +logger = logging.getLogger(__name__) + + +class ContextReorderPlugin(BasePlugin): + """ + Plugin for reordering prompts to maximize KV Cache prefix sharing. + Optimized for OpenAI-formatted request batches. + """ + + def __init__(self, model_name: str = "Qwen/Qwen2.5-7B-Instruct", alpha: float = 0.001, use_gpu: bool = False): + super().__init__("context_reorder") + from contextpilot.server.live_index import ContextPilot + from contextpilot.utils.prompt_generator import get_tokenizer + + self.pilot = ContextPilot(alpha=alpha, use_gpu=use_gpu, linkage_method="single") + self.pilot.num_workers = 1 + self.tokenizer = get_tokenizer(model_name) + if self.tokenizer is None: + logger.warning(f"Could not load tokenizer for {model_name}. Using fallback char-split.") + + # Telemetry + self.total_processed_batches = 0 + self.last_execution_time_ms = 0.0 + + def _tokenize(self, text: str) -> List[int]: + """Convert text to token IDs using the configured tokenizer.""" + if self.tokenizer: + return self.tokenizer.encode(text, add_special_tokens=False) + return [ord(c) for c in text] + + async def process(self, request_batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Process a batch of OpenAI requests. + """ + if not request_batch: + return [] + + start_time = time.perf_counter() + + # 1. Extract and Tokenize + tokenized_contexts = [] + for req in request_batch: + full_text = "\n".join([m.get("content", "") for m in req.get("messages", [])]) + tokenized_contexts.append(self._tokenize(full_text)) + + # 2. Run ContextPilot Scheduling + result = self.pilot.build_and_schedule(tokenized_contexts) + + # 3. Reorder the original JSON objects + new_order_indices = result["original_indices"] + reordered_batch = [request_batch[i] for i in new_order_indices] + + # Update Telemetry + self.last_execution_time_ms = (time.perf_counter() - start_time) * 1000 + self.total_processed_batches += 1 + + logger.info(f"Reordered batch of {len(request_batch)} requests in {self.last_execution_time_ms:.2f}ms") + return reordered_batch + + def get_plugin_metrics(self) -> Dict[str, float]: + """Return reordering metrics.""" + return { + "total_processed_batches": float(self.total_processed_batches), + "last_execution_time_ms": self.last_execution_time_ms, + } diff --git a/refactored_plugins/skill_index.py b/refactored_plugins/skill_index.py new file mode 100644 index 0000000..49b4a3b --- /dev/null +++ b/refactored_plugins/skill_index.py @@ -0,0 +1,57 @@ +import logging +import time +from typing import Any, Dict +from .base import BasePlugin + +logger = logging.getLogger(__name__) + + +class SkillAwareContextPlugin(BasePlugin): + """ + Plugin for dynamically injecting OpenAI tool schemas based on the required + skills specified by the framework's router. + """ + + def __init__(self, tool_registry: Dict[str, Dict]): + super().__init__("skill_aware_context") + self.tool_registry = tool_registry + + # Telemetry variables + self.total_tools_filtered = 0 + self.last_execution_time_ms = 0.0 + + async def process(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """ + Process the incoming request and inject the required tool schemas. + """ + start_time = time.perf_counter() + optimized_request = dict(request_data) + + required_skills = optimized_request.get("_required_skills") + if required_skills is not None: + injected_tools = [] + for skill in required_skills: + if skill in self.tool_registry: + injected_tools.append(self.tool_registry[skill]) + else: + logger.warning(f"Skill '{skill}' requested but not found in tool registry.") + + # OpenAI requires a 'tools' array + optimized_request["tools"] = injected_tools + + # Update telemetry + # Number of tools filtered out = total available - total injected + filtered_count = len(self.tool_registry) - len(injected_tools) + self.total_tools_filtered += filtered_count + + self.last_execution_time_ms = (time.perf_counter() - start_time) * 1000 + return optimized_request + + def get_plugin_metrics(self) -> Dict[str, float]: + """ + Return the telemetry data. + """ + return { + "total_tools_filtered": float(self.total_tools_filtered), + "last_execution_time_ms": self.last_execution_time_ms, + } diff --git a/requirements.txt b/requirements.txt index c9b66cf..8df2d2e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,5 @@ fastapi[all]>=0.115.0 uvicorn[standard]>=0.32.0 httpx>=0.28.0 pytest +pyzmq +msgspec diff --git a/scripts/benchmark_pageindex.py b/scripts/benchmark_pageindex.py index 9a571d5..fad5a83 100644 --- a/scripts/benchmark_pageindex.py +++ b/scripts/benchmark_pageindex.py @@ -545,7 +545,7 @@ async def run_query_benchmark( 'num_contexts': len(contexts), }) - self._log(f" ✓ {sr['qid']}: {sr['search_time']:.2f}s search, {gen_time:.2f}s gen") + self._log(f" + {sr['qid']}: {sr['search_time']:.2f}s search, {gen_time:.2f}s gen") return { 'results': results, diff --git a/tests/test_dedup_plugin.py b/tests/test_dedup_plugin.py new file mode 100644 index 0000000..7eb2303 --- /dev/null +++ b/tests/test_dedup_plugin.py @@ -0,0 +1,75 @@ + +import pytest +import asyncio +from refactored_plugins.dedup import ContextDedupPlugin + +@pytest.mark.asyncio +async def test_context_dedup_plugin_multi_turn(): + """ + Test: Deduplicate a two-turn conversation. + Turn 1: System + Question 1 + Turn 2: System + Question 1 + Answer 1 + Question 2 + """ + plugin = ContextDedupPlugin() + + # --- TURN 1 --- + system_msg = {"role": "system", "content": "You are a helpful assistant."} + q1 = {"role": "user", "content": "What is the capital of France?"} + + request_t1 = { + "user_id": "user123", + "messages": [system_msg, q1] + } + + # Process Turn 1 (should be no change, but registers the history) + resp_t1 = await plugin.process(request_t1) + assert len(resp_t1["messages"]) == 2 + assert resp_t1["messages"][0]["content"] == system_msg["content"] + + t1_id = resp_t1["current_id"] + + # --- TURN 2 (The Agent framework sends the whole history again) --- + a1 = {"role": "assistant", "content": "The capital of France is Paris."} + q2 = {"role": "user", "content": "And Germany?"} + + request_t2 = { + "user_id": "user123", + "parent_id": t1_id, # Link to Turn 1 + "messages": [ + system_msg, # Duplicate + q1, # Duplicate + a1, # New + q2 # New + ] + } + + # Calculate original length + original_total_len = sum(len(m["content"]) for m in request_t2["messages"]) + + # Process Turn 2 + resp_t2 = await plugin.process(request_t2) + + # Calculate deduplicated length + dedup_total_len = sum(len(m["content"]) for m in resp_t2["messages"]) + + # --- ASSERTIONS --- + assert len(resp_t2["messages"]) == 4 + + # The first two messages should now be hints + assert "[Reference to Turn" in resp_t2["messages"][0]["content"] + assert "[Reference to Turn" in resp_t2["messages"][1]["content"] + + # The new messages should remain intact + assert resp_t2["messages"][2]["content"] == a1["content"] + assert resp_t2["messages"][3]["content"] == q2["content"] + + # Compression check + assert dedup_total_len < original_total_len + + print(f"\nOriginal Length: {original_total_len} chars") + print(f"Dedup Length: {dedup_total_len} chars") + print(f"Compression: {(1 - dedup_total_len/original_total_len)*100:.2f}%") + print("\nSUCCESS: ContextDedupPlugin compressed multi-turn history using reference hints.") + +if __name__ == "__main__": + asyncio.run(test_context_dedup_plugin_multi_turn()) diff --git a/tests/test_kv_lookup.py b/tests/test_kv_lookup.py new file mode 100644 index 0000000..53ae1fe --- /dev/null +++ b/tests/test_kv_lookup.py @@ -0,0 +1,35 @@ +import pytest +from refactored_plugins.kv_lookup import ShadowRadixTree + + +@pytest.mark.asyncio +async def test_shadow_radix_tree(): + """ + Test the manual adding of overlapping blocks and longest_prefix_match functionality + of the ShadowRadixTree. + """ + tree = ShadowRadixTree() + + # 1. Manually add 3 overlapping blocks + # Block 1: parent None + tree.add_block(1, None, [100, 200, 300]) + + # Block 2: parent is Block 1 + tree.add_block(2, 1, [400, 500, 600]) + + # Block 3: parent is Block 2 + tree.add_block(3, 2, [700, 800, 900]) + + # Test 1: Completely new sequence (no match) + assert tree.longest_prefix_match([999, 888]) == 0 + + # Test 2: Partially matching sequence + # Matches tokens from Block 1 and part of Block 2 + assert tree.longest_prefix_match([100, 200, 300, 400, 500, 999]) == 5 + + # Test 3: Fully matching sequence + # Matches all tokens across all blocks + assert tree.longest_prefix_match([100, 200, 300, 400, 500, 600, 700, 800, 900]) == 9 + + # Test 4: Another partial match targeting only Block 1 + assert tree.longest_prefix_match([100, 200, 300, 999]) == 3 diff --git a/tests/test_reorder_plugin.py b/tests/test_reorder_plugin.py new file mode 100644 index 0000000..dfa81f5 --- /dev/null +++ b/tests/test_reorder_plugin.py @@ -0,0 +1,50 @@ + +import pytest +import asyncio +from refactored_plugins.reorder import ContextReorderPlugin + +@pytest.mark.asyncio +async def test_context_reorder_plugin_end_to_end(): + """ + E2E Test: Reorder a batch of OpenAI requests with overlapping system prompts. + """ + # 1. Setup Mock OpenAI Requests + # Group A: Common system prompt + Tool Set A + # Group B: Common system prompt + Tool Set B + system_prompt = "You are an assistant." + tools_a = "Tools for math: add, subtract." + tools_b = "Tools for coding: python, bash." + + requests = [ + {"id": "req1", "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": "How's the weather?"}]}, # No tools + {"id": "req2", "messages": [{"role": "system", "content": system_prompt}, {"role": "system", "content": tools_a}, {"role": "user", "content": "Add 5+5"}]}, # Tools A + {"id": "req3", "messages": [{"role": "system", "content": system_prompt}, {"role": "system", "content": tools_b}, {"role": "user", "content": "Write python script"}]}, # Tools B + {"id": "req4", "messages": [{"role": "system", "content": system_prompt}, {"role": "system", "content": tools_a}, {"role": "user", "content": "Subtract 10-2"}]}, # Tools A + ] + + # 2. Initialize Plugin (using small alpha to detect overlap) + # We use a dummy model name to avoid downloading 7B tokenizer during test (uses fallback) + plugin = ContextReorderPlugin(model_name="test-model", alpha=0.1) + + # 3. Process Batch + reordered_requests = await plugin.process(requests) + + # 4. Assertions + assert len(reordered_requests) == len(requests) + + # In a perfect world, req2 and req4 (Tools A) should be adjacent + # and follow the same prefix logic we saw in the verification script. + + # Let's find the positions of Tools A requests + tools_a_indices = [i for i, r in enumerate(reordered_requests) if "Add 5+5" in str(r) or "Subtract 10-2" in str(r)] + + # Check if they are adjacent + is_adjacent = abs(tools_a_indices[0] - tools_a_indices[1]) == 1 + + print("\nReordered Sequence IDs:", [r["id"] for r in reordered_requests]) + + assert is_adjacent, f"Tools A requests were not grouped together! Indices: {tools_a_indices}" + print("\nSUCCESS: ContextReorderPlugin grouped requests with shared tool definitions.") + +if __name__ == "__main__": + asyncio.run(test_context_reorder_plugin_end_to_end()) diff --git a/tests/test_skill_index.py b/tests/test_skill_index.py new file mode 100644 index 0000000..d01a2ee --- /dev/null +++ b/tests/test_skill_index.py @@ -0,0 +1,40 @@ +import pytest +from refactored_plugins.skill_index import SkillAwareContextPlugin + + +@pytest.fixture +def mock_tool_registry(): + return { + "math": { + "type": "function", + "function": {"name": "math_tool", "description": "Performs mathematical calculations"}, + }, + "weather": { + "type": "function", + "function": {"name": "weather_tool", "description": "Gets the current weather"}, + }, + "database": {"type": "function", "function": {"name": "db_tool", "description": "Queries the database"}}, + } + + +@pytest.mark.asyncio +async def test_skill_aware_context_plugin(mock_tool_registry): + # Initialize the plugin + plugin = SkillAwareContextPlugin(tool_registry=mock_tool_registry) + + # Pass a mock OpenAI request with "_required_skills": ["math"] + mock_request = {"messages": [{"role": "user", "content": "What is 2 + 2?"}], "_required_skills": ["math"]} + + # Process the request + modified_request = await plugin.process(mock_request) + + # Asserts that the returned request contains exactly 1 tool in its "tools" array + assert "tools" in modified_request + assert len(modified_request["tools"]) == 1 + + # Asserts that it is the correct "math" schema + assert modified_request["tools"][0]["function"]["name"] == "math_tool" + + # Asserts that the telemetry correctly tracks that 2 tools were filtered out. + metrics = plugin.get_plugin_metrics() + assert metrics["total_tools_filtered"] == 2.0 diff --git a/tests/test_utils.py b/tests/test_utils.py index 29d4a8a..73e656f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -20,6 +20,6 @@ def generate_contexts(num_contexts: int, chunk_ids = list(set(chunk_ids)) contexts.append(chunk_ids) - print(f"✓ Generated {num_contexts:,} contexts") + print(f"+ Generated {num_contexts:,} contexts") print(f" Avg chunks per context: {np.mean([len(c) for c in contexts]):.1f}") return contexts \ No newline at end of file diff --git a/verify_prefix_sharing.py b/verify_prefix_sharing.py new file mode 100644 index 0000000..dc4c63c --- /dev/null +++ b/verify_prefix_sharing.py @@ -0,0 +1,73 @@ + +import time +from typing import List +from contextpilot.server.live_index import ContextPilot, compute_prefix_length + +def calculate_total_prefix_sharing(contexts: List[List[int]]) -> int: + if not contexts: + return 0 + + total_sharing = 0 + for i in range(1, len(contexts)): + shared = compute_prefix_length(contexts[i-1], contexts[i]) + total_sharing += shared + return total_sharing + +def run_verification(): + print("=" * 60) + print("VERIFYING PREFIX SHARING BENEFIT") + print("=" * 60) + + # 1. Generate overlapping synthetic contexts + system_prompt = list(range(1, 101)) # 100 + tool_defs = list(range(101, 201)) # 100 + + contexts = [] + for i in range(20): + if i < 10: + ctx = system_prompt + tool_defs + [1000 + i] + else: + ctx = system_prompt + [2000 + i] + contexts.append(ctx) + + # 2. Calculate sharing BEFORE reordering + import random + random.seed(42) + shuffled_contexts = list(contexts) + random.shuffle(shuffled_contexts) + + sharing_before = calculate_total_prefix_sharing(shuffled_contexts) + + # 3. Apply ContextPilot reordering (using the full Pilot) + pilot = ContextPilot(use_gpu=False, alpha=0.1) + # build_and_schedule returns reordered_contexts in the scheduled order + result = pilot.build_and_schedule(shuffled_contexts) + reordered_contexts = result["reordered_contexts"] + + sharing_after = calculate_total_prefix_sharing(reordered_contexts) + + # 4. Debug: Check the first few reordered contexts + print("\nDEBUG: Reordered Context Structure") + for i in range(min(10, len(reordered_contexts))): + ctx = reordered_contexts[i] + has_tools = all(t in ctx for t in tool_defs[:10]) + print(f" Context {i}: len={len(ctx)}, has_tool_defs={has_tools}") + + # 5. Results + print(f"\nBatch Size: {len(contexts)} contexts") + print(f"Sharing BEFORE: {sharing_before} tokens") + print(f"Sharing AFTER: {sharing_after} tokens") + + improvement = (sharing_after - sharing_before) / (sharing_before + 1e-9) * 100 + print(f"Improvement: {improvement:.2f}%") + + if sharing_after > sharing_before: + print("\nSUCCESS: ContextPilot increased prefix sharing!") + elif sharing_after == sharing_before: + print("\nNEUTRAL: No change. Is the baseline already optimal?") + else: + print("\nFAILURE: Reordering decreased sharing!") + print("=" * 60) + +if __name__ == "__main__": + run_verification()