Skip to content

Commit 9e36c13

Browse files
author
Shrey Modi
committed
Merge branch 'main' of github.com:eval-protocol/python-sdk into zip-upload
2 parents 750a3e5 + 69dbd1b commit 9e36c13

30 files changed

+1037
-235
lines changed

.github/workflows/elasticsearch-tests.yml renamed to .github/workflows/fireworks-tracing-tests.yml

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Elasticsearch Tests
1+
name: Fireworks Tracing Tests
22

33
on:
44
push:
@@ -13,8 +13,8 @@ on:
1313
workflow_dispatch: # Allow manual triggering
1414

1515
jobs:
16-
elasticsearch-tests:
17-
name: Elasticsearch Integration Tests
16+
fireworks-tracing-tests:
17+
name: Fireworks Tracing Integration Tests
1818
runs-on: ubuntu-latest
1919

2020
steps:
@@ -36,14 +36,15 @@ jobs:
3636
- name: Install the project
3737
run: uv sync --locked --all-extras --dev
3838

39-
- name: Run Elasticsearch Tests
39+
- name: Run Fireworks Tracing Tests
4040
env:
4141
FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
4242
PYTHONWARNINGS: "ignore::DeprecationWarning,ignore::RuntimeWarning"
4343
run: |
44-
# Run Elasticsearch direct HTTP handler tests
45-
uv run pytest tests/logging/test_elasticsearch_direct_http_handler.py -v --tb=short
44+
# Run RemoteRolloutProcessor End-to-End Test (auto server startup)
45+
uv run pytest tests/remote_server/test_remote_fireworks.py::test_remote_rollout_and_fetch_fireworks \
46+
-v --tb=short
4647
47-
# Run RemoteRolloutProcessor Propagate Status Smoke Test (also uses Elasticsearch)
48+
# Run RemoteRolloutProcessor Propagate Status Test (auto server startup)
4849
uv run pytest tests/remote_server/test_remote_fireworks_propagate_status.py::test_remote_rollout_and_fetch_fireworks_propagate_status \
4950
-v --tb=short

.github/workflows/rollout.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ run-name: rollout:${{ fromJSON(inputs.metadata).rollout_id }}
55
on:
66
workflow_dispatch:
77
inputs:
8-
model:
9-
description: 'Model to use'
8+
completion_params:
9+
description: 'JSON completion params (optional, includes model_kwargs)'
1010
required: true
1111
type: string
1212
metadata:
@@ -18,6 +18,7 @@ on:
1818
required: true
1919
type: string
2020

21+
2122
jobs:
2223
rollout:
2324
runs-on: ubuntu-latest
@@ -41,6 +42,6 @@ jobs:
4142
FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
4243
run: |
4344
python tests/github_actions/rollout_worker.py \
44-
--model "${{ inputs.model }}" \
45+
--completion-params '${{ inputs.completion_params }}' \
4546
--metadata '${{ inputs.metadata }}' \
4647
--model-base-url "${{ inputs.model_base_url }}"

eval_protocol/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
filter_longest_conversation,
3737
)
3838
from .pytest import evaluation_test, SingleTurnRolloutProcessor, RemoteRolloutProcessor, GithubActionRolloutProcessor
39-
from .pytest.remote_rollout_processor import create_elasticsearch_config_from_env
4039
from .pytest.parameterize import DefaultParameterIdGenerator
4140
from .log_utils.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler
4241
from .log_utils.rollout_id_filter import RolloutIdFilter
@@ -90,7 +89,6 @@
9089
warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")
9190

9291
__all__ = [
93-
"create_elasticsearch_config_from_env",
9492
"ElasticsearchConfig",
9593
"ElasticsearchDirectHttpHandler",
9694
"RolloutIdFilter",

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) -
309309
"message": e.get("message"),
310310
"severity": e.get("severity", "INFO"),
311311
"tags": e.get("tags", []),
312+
"status": e.get("status"),
312313
}
313314
)
314315
return results

eval_protocol/cli_commands/logs.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -39,49 +39,6 @@ def logs_command(args):
3939
or os.environ.get("GATEWAY_URL")
4040
or "https://tracing.fireworks.ai"
4141
)
42-
try:
43-
if not use_fireworks:
44-
if getattr(args, "use_env_elasticsearch_config", False):
45-
# Use environment variables for configuration
46-
print("⚙️ Using environment variables for Elasticsearch config")
47-
from eval_protocol.pytest.remote_rollout_processor import (
48-
create_elasticsearch_config_from_env,
49-
)
50-
51-
elasticsearch_config = create_elasticsearch_config_from_env()
52-
# Ensure index exists with correct mapping, mirroring Docker setup path
53-
try:
54-
from eval_protocol.log_utils.elasticsearch_index_manager import (
55-
ElasticsearchIndexManager,
56-
)
57-
58-
index_manager = ElasticsearchIndexManager(
59-
elasticsearch_config.url,
60-
elasticsearch_config.index_name,
61-
elasticsearch_config.api_key,
62-
)
63-
created = index_manager.create_logging_index_mapping()
64-
if created:
65-
print(
66-
f"🧭 Verified Elasticsearch index '{elasticsearch_config.index_name}' mapping (created or already correct)"
67-
)
68-
else:
69-
print(
70-
f"⚠️ Could not verify/create mapping for index '{elasticsearch_config.index_name}'. Searches may behave unexpectedly."
71-
)
72-
except Exception as e:
73-
print(f"⚠️ Failed to ensure index mapping via IndexManager: {e}")
74-
elif not getattr(args, "disable_elasticsearch_setup", False):
75-
# Default behavior: start or connect to local Elasticsearch via Docker helper
76-
from eval_protocol.pytest.elasticsearch_setup import ElasticsearchSetup
77-
78-
print("🧰 Auto-configuring local Elasticsearch (Docker)")
79-
elasticsearch_config = ElasticsearchSetup().setup_elasticsearch()
80-
else:
81-
print("🚫 Elasticsearch setup disabled; running without Elasticsearch integration")
82-
except Exception as e:
83-
print(f"❌ Failed to configure Elasticsearch: {e}")
84-
return 1
8542

8643
try:
8744
serve_logs(
Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,25 @@
11
# Global event bus instance - uses SqliteEventBus for cross-process functionality
22
from eval_protocol.event_bus.event_bus import EventBus
3-
from eval_protocol.event_bus.sqlite_event_bus import SqliteEventBus
43

5-
event_bus: EventBus = SqliteEventBus()
4+
5+
def _get_default_event_bus():
6+
from eval_protocol.event_bus.sqlite_event_bus import SqliteEventBus
7+
8+
return SqliteEventBus()
9+
10+
11+
# Lazy property that creates the event bus only when accessed
12+
class _LazyEventBus(EventBus):
13+
def __init__(self):
14+
self._event_bus: EventBus | None = None
15+
16+
def _get_event_bus(self):
17+
if self._event_bus is None:
18+
self._event_bus = _get_default_event_bus()
19+
return self._event_bus
20+
21+
def __getattr__(self, name):
22+
return getattr(self._get_event_bus(), name)
23+
24+
25+
event_bus: EventBus = _LazyEventBus()

eval_protocol/log_utils/fireworks_tracing_http_handler.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,34 @@ def _get_rollout_id(self, record: logging.LogRecord) -> Optional[str]:
7979
return str(cast(Any, getattr(record, "rollout_id")))
8080
return os.getenv(self.rollout_id_env)
8181

82+
def _get_status_info(self, record: logging.LogRecord) -> Optional[Dict[str, Any]]:
83+
"""Extract status information from the log record's extra data."""
84+
# Check if 'status' is in the extra data (passed via extra parameter)
85+
if hasattr(record, "status") and record.status is not None: # type: ignore
86+
status = record.status # type: ignore
87+
88+
# Handle Status class instances (Pydantic BaseModel)
89+
if hasattr(status, "code") and hasattr(status, "message"):
90+
# Status object - extract code and message
91+
status_code = status.code
92+
# Handle both enum values and direct integer values
93+
if hasattr(status_code, "value"):
94+
status_code = status_code.value
95+
96+
return {
97+
"code": status_code,
98+
"message": status.message,
99+
"details": getattr(status, "details", []),
100+
}
101+
elif isinstance(status, dict):
102+
# Dictionary representation of status
103+
return {
104+
"code": status.get("code"),
105+
"message": status.get("message"),
106+
"details": status.get("details", []),
107+
}
108+
return None
109+
82110
def _build_payload(self, record: logging.LogRecord, rollout_id: str) -> Dict[str, Any]:
83111
timestamp = datetime.fromtimestamp(record.created, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
84112
message = record.getMessage()
@@ -96,28 +124,12 @@ def _build_payload(self, record: logging.LogRecord, rollout_id: str) -> Dict[str
96124
except Exception:
97125
pass
98126
program = cast(Optional[str], getattr(record, "program", None)) or "eval_protocol"
99-
status_val = cast(Any, getattr(record, "status", None))
100-
status = status_val if isinstance(status_val, str) else None
101-
# Capture optional structured status fields if present
102-
metadata: Dict[str, Any] = {}
103-
status_code = cast(Any, getattr(record, "status_code", None))
104-
if isinstance(status_code, int):
105-
metadata["status_code"] = status_code
106-
status_message = cast(Any, getattr(record, "status_message", None))
107-
if isinstance(status_message, str):
108-
metadata["status_message"] = status_message
109-
status_details = getattr(record, "status_details", None)
110-
if status_details is not None:
111-
metadata["status_details"] = status_details
112-
extra_metadata = cast(Any, getattr(record, "metadata", None))
113-
if isinstance(extra_metadata, dict):
114-
metadata.update(extra_metadata)
127+
115128
return {
116129
"program": program,
117-
"status": status,
130+
"status": self._get_status_info(record),
118131
"message": message,
119132
"tags": tags,
120-
"metadata": metadata or None,
121133
"extras": {
122134
"logger_name": record.name,
123135
"level": record.levelname,

eval_protocol/pytest/github_action_rollout_processor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import time
44
from typing import Any, Callable, Dict, List, Optional
5-
5+
import json
66
import requests
77
from datetime import datetime, timezone, timedelta
88
from eval_protocol.models import EvaluationRow, Status
@@ -87,10 +87,14 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
8787

8888
def _dispatch_workflow():
8989
url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/workflows/{self.workflow_id}/dispatches"
90+
91+
model = init_request.completion_params.get("model")
92+
if not model:
93+
raise ValueError("model is required in completion_params")
9094
payload = {
9195
"ref": self.ref,
9296
"inputs": {
93-
"model": init_request.model,
97+
"completion_params": json.dumps(init_request.completion_params),
9498
"metadata": init_request.metadata.model_dump_json(),
9599
"model_base_url": init_request.model_base_url,
96100
},

0 commit comments

Comments
 (0)