Skip to content

Commit 60b0fce

Browse files
authored
Klavis Sandbox on Fireworks EP (#388)
* init
* update logic
1 parent bd1be95 commit 60b0fce

6 files changed

Lines changed: 343 additions & 1 deletion

File tree

eval_protocol/pytest/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .default_agent_rollout_processor import AgentRolloutProcessor
22
from .default_dataset_adapter import default_dataset_adapter
3+
from .default_klavis_sandbox_rollout_processor import KlavisSandboxRolloutProcessor
34
from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
45
from .default_no_op_rollout_processor import NoOpRolloutProcessor
56
from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
@@ -31,6 +32,7 @@
3132

3233
__all__ = [
3334
"AgentRolloutProcessor",
35+
"KlavisSandboxRolloutProcessor",
3436
"MCPGymRolloutProcessor",
3537
"RolloutProcessor",
3638
"SingleTurnRolloutProcessor",
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
import asyncio
2+
import json
3+
import logging
4+
import os
5+
import tempfile
6+
import time
7+
from typing import Any, Callable, Dict, List, Optional
8+
9+
from pydantic import BaseModel, Field
10+
11+
from eval_protocol.models import EvaluationRow
12+
from eval_protocol.pytest.rollout_processor import RolloutProcessor
13+
from eval_protocol.pytest.types import RolloutProcessorConfig
14+
15+
from eval_protocol.pytest.default_agent_rollout_processor import Agent
16+
from klavis import Klavis
17+
from klavis.types import CreateSandboxResponse, SandboxMcpServer
18+
from openai.types import CompletionUsage
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class KlavisSandboxRolloutProcessor(RolloutProcessor):
    """Rollout processor that runs each evaluation row against a dedicated Klavis sandbox.

    Per-row lifecycle:
      0. Create a fresh sandbox for ``server_name``.
      1. Optionally seed it with initialization data (from a factory or the
         row's ``input_metadata.session_data["initialize_data"]``).
      2. Write a temporary MCP config pointing at the sandbox server.
      3. Run the agent against that MCP server.
      4. Dump the final sandbox state into ``row.execution_metadata.extra``.
      Finally: clean up the MCP client, temp config, and the sandbox itself.
    """

    def __init__(
        self,
        server_name: str,
        initialize_data_factory: Optional[Callable[[EvaluationRow], Dict[str, Any]]] = None,
    ):
        """
        Args:
            server_name: Klavis sandbox server name; must be a valid
                ``SandboxMcpServer`` enum value (validated in ``_init_sandbox``).
            initialize_data_factory: Optional callable deriving the sandbox
                initialization payload from a row. When omitted, the payload
                is read from ``row.input_metadata.session_data["initialize_data"]``.
        """
        super().__init__()
        self.server_name = server_name
        self.initialize_data_factory = initialize_data_factory
        # NOTE(review): api_key may be None if KLAVIS_API_KEY is unset; the
        # Klavis client is assumed to raise a clear error on first use.
        self.klavis_client = Klavis(api_key=os.environ.get("KLAVIS_API_KEY"))

    def _init_sandbox(self) -> CreateSandboxResponse:
        """Create a new sandbox for ``self.server_name``; log and re-raise on failure."""
        try:
            server_name_enum = SandboxMcpServer(self.server_name)
            return self.klavis_client.sandbox.create_sandbox(server_name=server_name_enum)
        except Exception as e:
            logger.error(f"Error creating sandbox: {str(e)}", exc_info=True)
            raise

    @staticmethod
    def create_mcp_config(server_url: str, server_key: str = "main", auth_token: str | None = None) -> str:
        """Create a temporary MCP config file and return its path.

        Args:
            server_url: URL of the MCP server to connect to.
            server_key: Key under ``mcpServers`` in the config.
            auth_token: Optional bearer token added as an ``authorization`` field.

        Returns:
            Path to the temp file. The caller owns cleanup (see ``__call__``'s
            ``finally`` block).
        """
        config = {
            "mcpServers": {
                server_key: {
                    "url": server_url,
                    "transport": "streamable_http",
                    **({"authorization": f"Bearer {auth_token}"} if auth_token else {}),
                }
            }
        }

        # Create a temp file that persists for the session
        fd, path = tempfile.mkstemp(suffix=".json", prefix="mcp_config_")
        with os.fdopen(fd, "w") as f:
            json.dump(config, f)
        return path

    def __call__(
        self, rows: List[EvaluationRow], config: RolloutProcessorConfig
    ) -> List[asyncio.Task[EvaluationRow]]:
        """Process evaluation rows with Klavis sandbox lifecycle management."""
        semaphore = config.semaphore

        async def process_row(row: EvaluationRow) -> EvaluationRow:
            """Process a single row with complete sandbox lifecycle."""

            start_time = time.perf_counter()
            agent: Agent | None = None
            temp_config_path: str | None = None
            sandbox: CreateSandboxResponse | None = None

            try:
                # Step 0: Create a sandbox for this row
                sandbox = self._init_sandbox()
                logger.info(f"Sandbox created: {sandbox}")

                # Step 1: Initialize data in the sandbox
                init_data: Dict[str, Any] | None = None
                if self.initialize_data_factory:
                    init_data = self.initialize_data_factory(row)
                else:
                    # Allow datasets to provide initialization payload directly
                    init_data = (
                        (row.input_metadata.session_data or {}).get("initialize_data")
                        if row.input_metadata is not None
                        else None
                    )

                if init_data:
                    logger.info(f"Initializing {self.server_name} sandbox {sandbox.sandbox_id}")
                    # Klavis exposes one initializer per server type
                    # (e.g. initialize_gmail_sandbox); resolve it dynamically.
                    initialize_method = getattr(
                        self.klavis_client.sandbox, f"initialize_{sandbox.server_name.value}_sandbox"
                    )
                    init_response = initialize_method(sandbox_id=sandbox.sandbox_id, **init_data)
                    logger.info(f"Initialization response: {init_response}")

                # Step 2: Create temporary MCP config with sandbox URL
                temp_config_path = self.create_mcp_config(
                    server_url=sandbox.server_url, server_key=sandbox.server_name.value
                )
                logger.info(f"MCP config created: {temp_config_path}")

                # Step 3: Run agent with sandbox MCP server
                logger.info(f"Running agent for row {row.execution_metadata.rollout_id} with {self.server_name} sandbox")
                agent = Agent(
                    model=row.input_metadata.completion_params["model"],
                    row=row,
                    config_path=temp_config_path,
                    logger=config.logger,
                )
                await agent.setup()
                await agent.call_agent()

                # BUGFIX: adopt the agent's row FIRST, then record usage on it.
                # Previously usage was written to the pre-agent row object and
                # would be lost if agent.evaluation_row is a different instance.
                row = agent.evaluation_row
                row.execution_metadata.usage = CompletionUsage(
                    prompt_tokens=agent.usage.get("prompt_tokens", 0),
                    completion_tokens=agent.usage.get("completion_tokens", 0),
                    total_tokens=agent.usage.get("total_tokens", 0),
                )
                logger.info(f"Agent execution completed for row {row.execution_metadata.rollout_id}")

                # Step 4: Export sandbox data (per-server dump method, e.g. dump_gmail_sandbox)
                dump_method = getattr(self.klavis_client.sandbox, f"dump_{sandbox.server_name.value}_sandbox")
                dump_response = dump_method(sandbox_id=sandbox.sandbox_id)
                sandbox_data = dump_response.data
                logger.info(f"Sandbox data: {sandbox_data}")

                # Store sandbox data in row metadata for evaluation
                if not row.execution_metadata.extra:
                    row.execution_metadata.extra = {}
                row.execution_metadata.extra["sandbox_data"] = sandbox_data
                row.execution_metadata.extra["sandbox_id"] = sandbox.sandbox_id
                row.execution_metadata.extra["server_name"] = self.server_name

            except Exception as e:
                logger.error(f"Error processing row {row.execution_metadata.rollout_id}: {str(e)}", exc_info=True)
                if not row.execution_metadata.extra:
                    row.execution_metadata.extra = {}
                row.execution_metadata.extra["error"] = str(e)
                raise

            finally:
                # Cleanup agent MCP client and temp config
                if agent and agent.mcp_client:
                    await agent.mcp_client.cleanup()
                if temp_config_path and os.path.exists(temp_config_path):
                    os.unlink(temp_config_path)

                # Release sandbox (best-effort; never mask the original error)
                if sandbox and sandbox.sandbox_id:
                    try:
                        self.klavis_client.sandbox.delete_sandbox(
                            server_name=sandbox.server_name, sandbox_id=sandbox.sandbox_id
                        )
                        logger.info(f"Sandbox {sandbox.sandbox_id} released successfully")
                    except Exception as e:
                        logger.error(f"Error releasing sandbox {sandbox.sandbox_id}: {str(e)}", exc_info=True)

                row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time

            return row

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            async with semaphore:
                result = await process_row(r)
                return result

        # Create and return tasks
        tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
        return tasks

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@ openenv = [
134134
dspy = [
135135
"dspy>=3.0.0",
136136
]
137+
klavis = [
138+
"klavis>=2.18.0",
139+
]
137140

138141
# Optional deps for LangGraph example/tests
139142
langgraph = [
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
{"initialize_data": {"messages": [{"subject": "Project Update", "to": "zihao@klavisai.com", "body": "The project is progressing well. We should have the final deliverables by next week.", "cc": "", "bcc": "", "from": "sarah@klavisai.com", "reply_to": "", "labels": ["INBOX"]}, {"subject": "Spam Newsletter", "to": "zihao@klavisai.com", "body": "Check out our amazing deals! Click here now!", "cc": "", "bcc": "", "from": "marketing@spammy.com", "reply_to": "", "labels": ["INBOX"]}], "drafts": []}, "messages": "Please delete the email with subject \"Spam Newsletter\" from my inbox.", "ground_truth": {"messages": [{"subject": "Project Update", "to": "zihao@klavisai.com", "body": "The project is progressing well. We should have the final deliverables by next week.", "cc": "", "bcc": "", "from": "sarah@klavisai.com", "reply_to": "", "labels": ["INBOX"]}], "drafts": []}}
2+
{"initialize_data": {"messages": [], "drafts": []}, "messages": "Please directly send an email to zihao@klavisai.com with subject \"Meeting Tomorrow\" and body \"Hi Zihao, just confirming our meeting tomorrow at 2pm. Best regards.\"", "ground_truth": {"messages": [{"subject": "Meeting Tomorrow", "to": "zihao@klavisai.com", "body": "Hi Zihao, just confirming our meeting tomorrow at 2pm. Best regards.", "cc": "", "bcc": "", "from": "", "reply_to": "", "labels": ["SENT"]}], "drafts": []}}
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
import json
2+
import logging
3+
import os
4+
5+
from eval_protocol.models import EvaluateResult, EvaluationRow, Message
6+
from eval_protocol.pytest import KlavisSandboxRolloutProcessor, evaluation_test
7+
from openai import AsyncOpenAI
8+
from pydantic import BaseModel
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class ResponseFormat(BaseModel):
    """Structured-output schema for the LLM judge: a single numeric score."""

    score: float
15+
16+
17+
def klavis_gmail_sandbox_dataset_adapter(rows: list[dict]) -> list[EvaluationRow]:
    """Dataset adapter for sandbox JSONL rows.

    Supports the new schema:
      - initialize_data: dict (passed to Klavis sandbox initializer)
      - messages: str (task instruction)
      - ground_truth: dict (expected final sandbox state)

    Rows not matching the new schema are passed straight to ``EvaluationRow(**row)``.
    """
    system_prompt = (
        "You are a helpful assistant with access to Gmail. "
        "You can send emails, draft emails, and manage messages, etc."
    )
    eval_rows: list[EvaluationRow] = []

    for raw in rows:
        # New-schema rows carry a string task under "messages" plus init data.
        if not (isinstance(raw.get("messages"), str) and "initialize_data" in raw):
            eval_rows.append(EvaluationRow(**raw))
            continue

        init_payload = raw.get("initialize_data") or {}
        instruction = raw.get("messages") or ""
        expected_state = raw.get("ground_truth")

        adapted_row = EvaluationRow(
            messages=[
                Message(role="system", content=system_prompt),
                Message(role="user", content=instruction),
            ],
            ground_truth=expected_state,
        )
        # Stash the init payload and task so the rollout processor / judge can read them.
        adapted_row.input_metadata.session_data = {
            "initialize_data": init_payload,
            "task": instruction,
        }
        eval_rows.append(adapted_row)

    return eval_rows
54+
55+
56+
@evaluation_test(
    input_dataset=["tests/pytest/datasets/klavis_gmail_sandbox_test.jsonl"],
    rollout_processor=KlavisSandboxRolloutProcessor(
        server_name="gmail",
    ),
    completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-thinking"}],
    mode="pointwise",
    dataset_adapter=klavis_gmail_sandbox_dataset_adapter,
)
async def test_pytest_gmail_sandbox(row: EvaluationRow) -> EvaluationRow:
    """
    Evaluate Gmail sandbox results by comparing with ground truth using LLM judge.

    The sandbox data is exported after agent execution and compared with expected output.
    Sandbox data is available in row.execution_metadata.extra["sandbox_data"].

    Returns the row with ``evaluation_result`` populated: the judge's score on
    success, or score 0.0 with an error reason if judging fails.
    """
    ground_truth = row.ground_truth
    sandbox_data = row.execution_metadata.extra.get("sandbox_data", {}) if row.execution_metadata.extra else {}
    final_message = row.messages[-1].content if row.messages else ""
    initialize_data = (row.input_metadata.session_data or {}).get("initialize_data", {})
    task = (row.input_metadata.session_data or {}).get("task", "")

    logger.info(f"Evaluating row {row.execution_metadata.rollout_id}")
    logger.info(f"Final message: {final_message}")
    logger.info(f"Sandbox data: {json.dumps(sandbox_data, indent=2, default=str)}")
    logger.info(f"Ground truth: {ground_truth}")

    async with AsyncOpenAI(
        api_key=os.environ["FIREWORKS_API_KEY"], base_url="https://api.fireworks.ai/inference/v1"
    ) as client:

        evaluation_prompt = f"""You are evaluating an AI agent's performance on a Gmail sandbox task.

Task:
{task or (row.messages[-1].content if row.messages else 'N/A')}

Initial Gmail Sandbox State (initialize_data):
{json.dumps(initialize_data, indent=2, default=str)}

Expected Final Gmail Sandbox State (ground_truth):
{json.dumps(ground_truth, indent=2, default=str)}

Gmail Sandbox State After Execution:
{json.dumps(sandbox_data, indent=2, default=str)}

Evaluate whether the agent successfully completed the task by checking:
1. Does the final sandbox state match the expected ground_truth state?
2. If there are small formatting differences, judge semantically
3. Use the initial state only as context; the key is whether the correct changes happened.

Return:
- score: 1.0 if task completed successfully, 0.5 if partially completed, 0.0 if failed

"""

        try:
            response = await client.chat.completions.create(
                model="accounts/fireworks/models/kimi-k2-thinking",
                messages=[
                    {
                        "role": "system",
                        "content": "You are a precise evaluator of AI agent performance. Analyze the task, execution, and results carefully.",
                    },
                    {"role": "user", "content": evaluation_prompt},
                ],
                response_format={
                    "type": "json_schema",
                    "json_schema": {"name": "ResponseFormat", "schema": ResponseFormat.model_json_schema()},
                },
                temperature=0.0,
            )

            response_text = response.choices[0].message.content
            logger.info(f"LLM judge response: {response_text}")

            parsed = json.loads(response_text or "{}")
            # BUGFIX: coerce to float so a malformed judge payload (e.g. a
            # string score) raises here and is handled by the except-path,
            # instead of propagating a non-numeric score downstream.
            score = float(parsed.get("score", 0.0))

            row.evaluation_result = EvaluateResult(score=score)
        except Exception as e:
            logger.error(f"Error during LLM evaluation: {str(e)}", exc_info=True)
            row.evaluation_result = EvaluateResult(
                score=0.0,
                reason=f"Evaluation error: {str(e)}",
            )

    return row

uv.lock

Lines changed: 20 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)