Skip to content

Commit ce6742f

Browse files
committed
Merge branch 'main' of github.com:eval-protocol/python-sdk into shrey/lazyloading
2 parents 51e5305 + b42c27d commit ce6742f

19 files changed

+412
-87
lines changed

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -253,17 +253,24 @@ def __init__(
253253
project_id: Optional[str] = None,
254254
base_url: str = "https://tracing.fireworks.ai",
255255
timeout: int = 300,
256+
api_key: Optional[str] = None,
256257
):
257258
"""Initialize the Fireworks Tracing adapter.
258259
259260
Args:
260261
project_id: Optional project ID. If not provided, uses the default project configured on the server.
261262
base_url: The base URL of the tracing proxy (default: https://tracing.fireworks.ai)
262263
timeout: Request timeout in seconds (default: 300)
264+
api_key: Optional API key. If not provided, falls back to FIREWORKS_API_KEY environment variable.
263265
"""
264266
self.project_id = project_id
265267
self.base_url = base_url.rstrip("/")
266268
self.timeout = timeout
269+
self._api_key = api_key
270+
271+
def _get_api_key(self) -> Optional[str]:
272+
"""Return the instance-level API key when it is truthy, otherwise the FIREWORKS_API_KEY environment variable (None if that is unset as well)."""
273+
return self._api_key or os.environ.get("FIREWORKS_API_KEY")
267274

268275
def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) -> List[Dict[str, Any]]:
269276
"""Fetch logs from Fireworks tracing gateway /logs endpoint.
@@ -276,7 +283,7 @@ def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) -
276283
from ..common_utils import get_user_agent
277284

278285
headers = {
279-
"Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}",
286+
"Authorization": f"Bearer {self._get_api_key()}",
280287
"User-Agent": get_user_agent(),
281288
}
282289
params: Dict[str, Any] = {"tags": tags, "limit": limit, "hours_back": hours_back, "program": "eval_protocol"}
@@ -407,7 +414,7 @@ def get_evaluation_rows(
407414
from ..common_utils import get_user_agent
408415

409416
headers = {
410-
"Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}",
417+
"Authorization": f"Bearer {self._get_api_key()}",
411418
"User-Agent": get_user_agent(),
412419
}
413420

eval_protocol/cli.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -284,8 +284,10 @@ def main():
284284
from dotenv import load_dotenv
285285

286286
# .env.dev for development-specific overrides, .env for general
287+
# Use explicit paths to avoid find_dotenv() searching up the directory tree
288+
# and potentially finding a different .env file (e.g., in some other repo)
287289
load_dotenv(dotenv_path=Path(".") / ".env.dev", override=True)
288-
load_dotenv(override=True)
290+
load_dotenv(dotenv_path=Path(".") / ".env", override=True)
289291
except ImportError:
290292
pass
291293

eval_protocol/dataset_logger/sqlite_evaluation_row_store.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from eval_protocol.event_bus.sqlite_event_bus_database import (
88
SQLITE_HARDENED_PRAGMAS,
99
check_and_repair_database,
10+
connect_with_retry,
1011
execute_with_sqlite_retry,
1112
)
1213
from eval_protocol.models import EvaluationRow
@@ -42,9 +43,10 @@ class EvaluationRow(BaseModel): # type: ignore
4243

4344
self._EvaluationRow = EvaluationRow
4445

45-
self._db.connect()
46+
# Connect with retry logic that properly handles pragma execution failures
47+
connect_with_retry(self._db)
4648
# Use safe=True to avoid errors when tables/indexes already exist
47-
self._db.create_tables([EvaluationRow], safe=True)
49+
execute_with_sqlite_retry(lambda: self._db.create_tables([EvaluationRow], safe=True))
4850

4951
@property
5052
def db_path(self) -> str:

eval_protocol/event_bus/sqlite_event_bus_database.py

Lines changed: 48 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,8 @@
1111

1212

1313
# Retry configuration for database operations
14-
SQLITE_RETRY_MAX_TRIES = 5
15-
SQLITE_RETRY_MAX_TIME = 30 # seconds
14+
SQLITE_RETRY_MAX_TRIES = 10
15+
SQLITE_RETRY_MAX_TIME = 60 # seconds
1616

1717

1818
def _is_database_locked_error(e: Exception) -> bool:
@@ -55,6 +55,49 @@ def _execute() -> T:
5555
return _execute()
5656

5757

58+
def connect_with_retry(db: SqliteDatabase) -> None:
59+
"""
60+
Connect to the database with retry logic, ensuring pragmas are always applied.
61+
62+
Peewee's connect() method sets the connection state *before* executing pragmas
63+
(in _initialize_connection). If pragma execution fails with "database is locked",
64+
the connection is marked as open but pragmas are not applied. Subsequent calls
65+
to connect(reuse_if_open=True) would see the connection as already open and
66+
skip pragma execution entirely.
67+
68+
This function handles this edge case by:
69+
1. Closing any open connection before each attempt, and again whenever connect() raises OperationalError
70+
2. Retrying with exponential backoff and full jitter only while _is_database_locked_error(e) is true, bounded by SQLITE_RETRY_MAX_TRIES attempts and SQLITE_RETRY_MAX_TIME seconds; any other OperationalError is re-raised without retrying
71+
72+
Args:
73+
db: The SqliteDatabase instance to connect (if it is already open, it is closed and reopened)
74+
"""
75+
76+
@backoff.on_exception(
77+
backoff.expo,
78+
OperationalError,
79+
max_tries=SQLITE_RETRY_MAX_TRIES,
80+
max_time=SQLITE_RETRY_MAX_TIME,
81+
giveup=lambda e: not _is_database_locked_error(e),
82+
jitter=backoff.full_jitter,
83+
)
84+
def _connect() -> None:
85+
try:
86+
# Close any open (possibly partially-initialized) connection before every
87+
# attempt so a fresh connection is opened and pragmas are executed
88+
if not db.is_closed():
89+
db.close()
90+
db.connect()
91+
except OperationalError:
92+
# If connect fails (e.g., during pragma execution), ensure the
93+
# connection is closed so the next retry starts fresh; the error is re-raised for backoff to handle
94+
if not db.is_closed():
95+
db.close()
96+
raise
97+
98+
_connect()
99+
100+
58101
# SQLite pragmas for hardened concurrency safety
59102
SQLITE_HARDENED_PRAGMAS = {
60103
"journal_mode": "wal", # Write-Ahead Logging for concurrent reads/writes
@@ -181,9 +224,10 @@ class Event(BaseModel): # type: ignore
181224
processed = BooleanField(default=False) # Track if event has been processed
182225

183226
self._Event = Event
184-
self._db.connect()
227+
# Connect with retry logic that properly handles pragma execution failures
228+
connect_with_retry(self._db)
185229
# Use safe=True to avoid errors when tables already exist
186-
self._db.create_tables([Event], safe=True)
230+
execute_with_sqlite_retry(lambda: self._db.create_tables([Event], safe=True))
187231

188232
def publish_event(self, event_type: str, data: Any, process_id: str) -> None:
189233
"""Publish an event to the database."""

eval_protocol/mcp/mcp_multi_client.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@ class FunctionLike(BaseModel):
1313
parameters: Any = None
1414

1515

16-
from dotenv import load_dotenv
1716
from mcp import ClientSession, StdioServerParameters
1817
from mcp.client.stdio import stdio_client
1918
from mcp.client.streamable_http import streamablehttp_client
@@ -26,8 +25,6 @@ class FunctionLike(BaseModel):
2625
MCPMultiClientConfiguration,
2726
)
2827

29-
load_dotenv() # load environment variables from .env
30-
3128

3229
class MCPMultiClient:
3330
"""

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
from openai.types import CompletionUsage
2323
from eval_protocol.pytest.rollout_processor import RolloutProcessor
2424
from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig
25+
from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm
2526
from pydantic import BaseModel
2627
from typing import Optional
2728

@@ -251,8 +252,11 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
251252
"""Process a single row with agent rollout."""
252253
start_time = time.perf_counter()
253254

255+
# Normalize Fireworks model names for LiteLLM routing
256+
completion_params = normalize_fireworks_model_for_litellm(row.input_metadata.completion_params) or {}
257+
row.input_metadata.completion_params = completion_params
254258
agent = Agent(
255-
model=row.input_metadata.completion_params["model"],
259+
model=completion_params["model"],
256260
row=row,
257261
config_path=config.mcp_config_path,
258262
logger=config.logger,

eval_protocol/pytest/default_klavis_sandbox_rollout_processor.py

Lines changed: 27 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
from eval_protocol.models import EvaluationRow
1212
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1313
from eval_protocol.pytest.types import RolloutProcessorConfig
14+
from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm
1415

1516
from eval_protocol.pytest.default_agent_rollout_processor import Agent
1617
from klavis import Klavis
@@ -30,15 +31,15 @@ def __init__(
3031
self.server_name = server_name
3132
self.initialize_data_factory = initialize_data_factory
3233
self.klavis_client = Klavis(api_key=os.environ.get("KLAVIS_API_KEY"))
33-
34+
3435
def _init_sandbox(self) -> CreateSandboxResponse:
"""Create a sandbox via self.klavis_client for self.server_name (converted to SandboxMcpServer); any exception is logged with its traceback and re-raised."""
3536
try:
3637
server_name_enum = SandboxMcpServer(self.server_name)
3738
return self.klavis_client.sandbox.create_sandbox(server_name=server_name_enum)
3839
except Exception as e:
3940
logger.error(f"Error creating sandbox: {str(e)}", exc_info=True)
4041
raise
41-
42+
4243
@staticmethod
4344
def create_mcp_config(server_url: str, server_key: str = "main", auth_token: str | None = None) -> str:
4445
"""Create a temporary MCP config file and return its path."""
@@ -47,26 +48,24 @@ def create_mcp_config(server_url: str, server_key: str = "main", auth_token: str
4748
server_key: {
4849
"url": server_url,
4950
"transport": "streamable_http",
50-
**({"authorization": f"Bearer {auth_token}"} if auth_token else {})
51+
**({"authorization": f"Bearer {auth_token}"} if auth_token else {}),
5152
}
5253
}
5354
}
54-
55+
5556
# Create a temp file that persists for the session
5657
fd, path = tempfile.mkstemp(suffix=".json", prefix="mcp_config_")
57-
with os.fdopen(fd, 'w') as f:
58+
with os.fdopen(fd, "w") as f:
5859
json.dump(config, f)
5960
return path
6061

61-
def __call__(
62-
self, rows: List[EvaluationRow], config: RolloutProcessorConfig
63-
) -> List[asyncio.Task[EvaluationRow]]:
62+
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
6463
"""Process evaluation rows with Klavis sandbox lifecycle management"""
6564
semaphore = config.semaphore
6665

6766
async def process_row(row: EvaluationRow) -> EvaluationRow:
6867
"""Process a single row with complete sandbox lifecycle"""
69-
68+
7069
start_time = time.perf_counter()
7170
agent: Agent | None = None
7271
temp_config_path: str | None = None
@@ -88,25 +87,32 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
8887
if row.input_metadata is not None
8988
else None
9089
)
91-
90+
9291
if init_data:
93-
logger.info(f"Initializing {self.server_name} sandbox {sandbox.sandbox_id}")
92+
logger.info(f"Initializing {self.server_name} sandbox {sandbox.sandbox_id}") # pyright: ignore[reportOptionalMemberAccess]
9493
initialize_method = getattr(
95-
self.klavis_client.sandbox, f"initialize_{sandbox.server_name.value}_sandbox"
94+
self.klavis_client.sandbox,
95+
f"initialize_{sandbox.server_name.value}_sandbox", # pyright: ignore[reportOptionalMemberAccess]
9696
)
97-
init_response = initialize_method(sandbox_id=sandbox.sandbox_id, **init_data)
97+
init_response = initialize_method(sandbox_id=sandbox.sandbox_id, **init_data) # pyright: ignore[reportOptionalMemberAccess]
9898
logger.info(f"Initialization response: {init_response}")
99-
99+
100100
# Step 2: Create temporary MCP config with sandbox URL
101101
temp_config_path = self.create_mcp_config(
102-
server_url=sandbox.server_url, server_key=sandbox.server_name.value
102+
server_url=sandbox.server_url, # pyright: ignore[reportOptionalMemberAccess]
103+
server_key=sandbox.server_name.value, # pyright: ignore[reportOptionalMemberAccess]
103104
)
104105
logger.info(f"MCP config created: {temp_config_path}")
105106

106107
# Step 3: Run agent with sandbox MCP server
107-
logger.info(f"Running agent for row {row.execution_metadata.rollout_id} with {self.server_name} sandbox")
108+
logger.info(
109+
f"Running agent for row {row.execution_metadata.rollout_id} with {self.server_name} sandbox"
110+
)
111+
# Normalize Fireworks model names for LiteLLM routing
112+
completion_params = normalize_fireworks_model_for_litellm(row.input_metadata.completion_params) or {}
113+
row.input_metadata.completion_params = completion_params
108114
agent = Agent(
109-
model=row.input_metadata.completion_params["model"],
115+
model=completion_params["model"],
110116
row=row,
111117
config_path=temp_config_path,
112118
logger=config.logger,
@@ -124,16 +130,16 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
124130
logger.info(f"Agent execution completed for row {row.execution_metadata.rollout_id}")
125131

126132
# Step 4: Export sandbox data
127-
dump_method = getattr(self.klavis_client.sandbox, f"dump_{sandbox.server_name.value}_sandbox")
128-
dump_response = dump_method(sandbox_id=sandbox.sandbox_id)
133+
dump_method = getattr(self.klavis_client.sandbox, f"dump_{sandbox.server_name.value}_sandbox") # pyright: ignore[reportOptionalMemberAccess]
134+
dump_response = dump_method(sandbox_id=sandbox.sandbox_id) # pyright: ignore[reportOptionalMemberAccess]
129135
sandbox_data = dump_response.data
130136
logger.info(f"Sandbox data: {sandbox_data}")
131137

132138
# Store sandbox data in row metadata for evaluation
133139
if not row.execution_metadata.extra:
134140
row.execution_metadata.extra = {}
135141
row.execution_metadata.extra["sandbox_data"] = sandbox_data
136-
row.execution_metadata.extra["sandbox_id"] = sandbox.sandbox_id
142+
row.execution_metadata.extra["sandbox_id"] = sandbox.sandbox_id # pyright: ignore[reportOptionalMemberAccess]
137143
row.execution_metadata.extra["server_name"] = self.server_name
138144

139145
except Exception as e:
@@ -149,7 +155,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
149155
await agent.mcp_client.cleanup()
150156
if temp_config_path and os.path.exists(temp_config_path):
151157
os.unlink(temp_config_path)
152-
158+
153159
# Release sandbox
154160
if sandbox and sandbox.sandbox_id:
155161
try:

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
from eval_protocol.models import EvaluationRow
1515
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1616
from eval_protocol.pytest.types import RolloutProcessorConfig, ServerMode
17+
from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm
1718

1819

1920
class MCPServerManager:
@@ -280,17 +281,20 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
280281
"Cannot retry without existing server/environments. Call with start_server=True first."
281282
)
282283

283-
model_id = str((config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini")
284-
temperature = config.completion_params.get("temperature", 0.0)
285-
max_tokens = config.completion_params.get("max_tokens", 4096)
284+
# Normalize Fireworks model names for LiteLLM routing
285+
completion_params = normalize_fireworks_model_for_litellm(config.completion_params) or {}
286+
# Update all rows with normalized completion_params
287+
for row in rows:
288+
row.input_metadata.completion_params = completion_params
289+
model_id = str(completion_params.get("model") or "gpt-4o-mini")
290+
temperature = completion_params.get("temperature", 0.0)
291+
max_tokens = completion_params.get("max_tokens", 4096)
286292

287293
# Pass all other completion_params (e.g. stream=True) via kwargs
288294
other_params = {
289-
k: v
290-
for k, v in (config.completion_params or {}).items()
291-
if k not in ["model", "temperature", "max_tokens", "extra_body"]
295+
k: v for k, v in completion_params.items() if k not in ["model", "temperature", "max_tokens", "extra_body"]
292296
}
293-
extra_body = config.completion_params.get("extra_body", {}) or {}
297+
extra_body = completion_params.get("extra_body", {}) or {}
294298

295299
self.policy = ep.LiteLLMPolicy(
296300
model_id=model_id,

0 commit comments

Comments
 (0)