Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions eval_protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
filter_longest_conversation,
)
from .pytest import evaluation_test, SingleTurnRolloutProcessor, RemoteRolloutProcessor, GithubActionRolloutProcessor
from .pytest.remote_rollout_processor import create_elasticsearch_config_from_env
from .pytest.parameterize import DefaultParameterIdGenerator
from .log_utils.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler
from .log_utils.rollout_id_filter import RolloutIdFilter
Expand Down Expand Up @@ -90,7 +89,6 @@
warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")

__all__ = [
"create_elasticsearch_config_from_env",
"ElasticsearchConfig",
"ElasticsearchDirectHttpHandler",
"RolloutIdFilter",
Expand Down
1 change: 1 addition & 0 deletions eval_protocol/adapters/fireworks_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) -
"message": e.get("message"),
"severity": e.get("severity", "INFO"),
"tags": e.get("tags", []),
"status": e.get("status"),
}
)
return results
Expand Down
43 changes: 0 additions & 43 deletions eval_protocol/cli_commands/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,49 +39,6 @@ def logs_command(args):
or os.environ.get("GATEWAY_URL")
or "https://tracing.fireworks.ai"
)
try:
if not use_fireworks:
if getattr(args, "use_env_elasticsearch_config", False):
# Use environment variables for configuration
print("⚙️ Using environment variables for Elasticsearch config")
from eval_protocol.pytest.remote_rollout_processor import (
create_elasticsearch_config_from_env,
)

elasticsearch_config = create_elasticsearch_config_from_env()
# Ensure index exists with correct mapping, mirroring Docker setup path
try:
from eval_protocol.log_utils.elasticsearch_index_manager import (
ElasticsearchIndexManager,
)

index_manager = ElasticsearchIndexManager(
elasticsearch_config.url,
elasticsearch_config.index_name,
elasticsearch_config.api_key,
)
created = index_manager.create_logging_index_mapping()
if created:
print(
f"🧭 Verified Elasticsearch index '{elasticsearch_config.index_name}' mapping (created or already correct)"
)
else:
print(
f"⚠️ Could not verify/create mapping for index '{elasticsearch_config.index_name}'. Searches may behave unexpectedly."
)
except Exception as e:
print(f"⚠️ Failed to ensure index mapping via IndexManager: {e}")
elif not getattr(args, "disable_elasticsearch_setup", False):
# Default behavior: start or connect to local Elasticsearch via Docker helper
from eval_protocol.pytest.elasticsearch_setup import ElasticsearchSetup

print("🧰 Auto-configuring local Elasticsearch (Docker)")
elasticsearch_config = ElasticsearchSetup().setup_elasticsearch()
else:
print("🚫 Elasticsearch setup disabled; running without Elasticsearch integration")
except Exception as e:
print(f"❌ Failed to configure Elasticsearch: {e}")
return 1

try:
serve_logs(
Expand Down
48 changes: 30 additions & 18 deletions eval_protocol/log_utils/fireworks_tracing_http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,34 @@ def _get_rollout_id(self, record: logging.LogRecord) -> Optional[str]:
return str(cast(Any, getattr(record, "rollout_id")))
return os.getenv(self.rollout_id_env)

def _get_status_info(self, record: logging.LogRecord) -> Optional[Dict[str, Any]]:
"""Extract status information from the log record's extra data."""
# Check if 'status' is in the extra data (passed via extra parameter)
if hasattr(record, "status") and record.status is not None: # type: ignore
status = record.status # type: ignore

# Handle Status class instances (Pydantic BaseModel)
if hasattr(status, "code") and hasattr(status, "message"):
# Status object - extract code and message
status_code = status.code
# Handle both enum values and direct integer values
if hasattr(status_code, "value"):
status_code = status_code.value

return {
"status_code": status_code,
"status_message": status.message,
"status_details": getattr(status, "details", []),
}
elif isinstance(status, dict):
# Dictionary representation of status
return {
"status_code": status.get("code"),
"status_message": status.get("message"),
"status_details": status.get("details", []),
}
return None

def _build_payload(self, record: logging.LogRecord, rollout_id: str) -> Dict[str, Any]:
timestamp = datetime.fromtimestamp(record.created, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
message = record.getMessage()
Expand All @@ -96,28 +124,12 @@ def _build_payload(self, record: logging.LogRecord, rollout_id: str) -> Dict[str
except Exception:
pass
program = cast(Optional[str], getattr(record, "program", None)) or "eval_protocol"
status_val = cast(Any, getattr(record, "status", None))
status = status_val if isinstance(status_val, str) else None
# Capture optional structured status fields if present
metadata: Dict[str, Any] = {}
status_code = cast(Any, getattr(record, "status_code", None))
if isinstance(status_code, int):
metadata["status_code"] = status_code
status_message = cast(Any, getattr(record, "status_message", None))
if isinstance(status_message, str):
metadata["status_message"] = status_message
status_details = getattr(record, "status_details", None)
if status_details is not None:
metadata["status_details"] = status_details
extra_metadata = cast(Any, getattr(record, "metadata", None))
if isinstance(extra_metadata, dict):
metadata.update(extra_metadata)

return {
"program": program,
"status": status,
"status": self._get_status_info(record),
"message": message,
"tags": tags,
"metadata": metadata or None,
"extras": {
"logger_name": record.name,
"level": record.levelname,
Expand Down
94 changes: 30 additions & 64 deletions eval_protocol/pytest/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@

import requests

from eval_protocol.log_utils.elasticsearch_client import ElasticsearchClient
from eval_protocol.models import EvaluationRow, Status
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
from eval_protocol.types.remote_rollout_processor import (
DataLoaderConfig,
ElasticsearchConfig,
)
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter

from .rollout_processor import RolloutProcessor
from .types import RolloutProcessorConfig
from .elasticsearch_setup import ElasticsearchSetup
from .tracing_utils import default_fireworks_output_data_loader, build_init_request, update_row_with_remote_trace
import logging

Expand All @@ -22,25 +21,6 @@
logger = logging.getLogger(__name__)


def create_elasticsearch_config_from_env() -> ElasticsearchConfig:
    """Build an ElasticsearchConfig from required environment variables.

    Reads ELASTICSEARCH_URL, ELASTICSEARCH_API_KEY and ELASTICSEARCH_INDEX_NAME;
    raises ValueError naming the first variable that is missing.
    """
    # Fail fast with a clear message if any required variable is absent.
    required = ("ELASTICSEARCH_URL", "ELASTICSEARCH_API_KEY", "ELASTICSEARCH_INDEX_NAME")
    values = {}
    for env_var in required:
        value = os.getenv(env_var)
        if value is None:
            raise ValueError(f"{env_var} must be set")
        values[env_var] = value
    return ElasticsearchConfig(
        url=values["ELASTICSEARCH_URL"],
        api_key=values["ELASTICSEARCH_API_KEY"],
        index_name=values["ELASTICSEARCH_INDEX_NAME"],
    )


class RemoteRolloutProcessor(RolloutProcessor):
"""
Rollout processor that triggers a remote HTTP server to perform the rollout.
Expand All @@ -59,8 +39,6 @@ def __init__(
poll_interval: float = 1.0,
timeout_seconds: float = 120.0,
output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None,
disable_elastic_search_setup: bool = False,
elastic_search_config: Optional[ElasticsearchConfig] = None,
):
# Prefer constructor-provided configuration. These can be overridden via
# config.kwargs at call time for backward compatibility.
Expand All @@ -74,21 +52,7 @@ def __init__(
self._poll_interval = poll_interval
self._timeout_seconds = timeout_seconds
self._output_data_loader = output_data_loader or default_fireworks_output_data_loader
self._disable_elastic_search_setup = disable_elastic_search_setup
self._elastic_search_config = elastic_search_config

def setup(self) -> None:
if self._disable_elastic_search_setup:
logger.info("Elasticsearch is disabled, skipping setup")
return
logger.info("Setting up Elasticsearch")
self._elastic_search_config = self._setup_elastic_search()
logger.info("Elasticsearch setup complete")

def _setup_elastic_search(self) -> ElasticsearchConfig:
"""Set up Elasticsearch using the dedicated setup module."""
setup = ElasticsearchSetup()
return setup.setup_elasticsearch()
self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url)

def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
tasks: List[asyncio.Task[EvaluationRow]] = []
Expand Down Expand Up @@ -123,7 +87,7 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
if row.input_metadata.row_id is None:
raise ValueError("Row ID is required in RemoteRolloutProcessor")

init_payload = build_init_request(row, config, model_base_url, self._elastic_search_config)
init_payload = build_init_request(row, config, model_base_url)

# Fire-and-poll
def _post_init() -> None:
Expand Down Expand Up @@ -153,10 +117,6 @@ def _get_status() -> Dict[str, Any]:
r.raise_for_status()
return r.json()

elasticsearch_client = (
ElasticsearchClient(self._elastic_search_config) if self._elastic_search_config else None
)

continue_polling_status = True
while time.time() < deadline:
try:
Expand All @@ -178,30 +138,36 @@ def _get_status() -> Dict[str, Any]:
# For all other exceptions, raise them
raise

if not elasticsearch_client:
continue

search_results = elasticsearch_client.search_by_status_code_not_in(
row.execution_metadata.rollout_id, [Status.Code.RUNNING]
# Search Fireworks tracing logs for completion
completed_logs = self._tracing_adapter.search_logs(
tags=[f"rollout_id:{row.execution_metadata.rollout_id}"]
)
hits = search_results["hits"]["hits"] if search_results else []
if completed_logs:
latest_log = completed_logs[0]

logger.info(
f"Found completion log for rollout {row.execution_metadata.rollout_id}: {latest_log.get('message', '')}"
)

# Look for structured status dictionary in status field
status_dict = latest_log.get("status")
if status_dict and isinstance(status_dict, dict) and "status_code" in status_dict:
status_code = status_dict.get("status_code")
status_message = status_dict.get("status_message", "")
status_details = status_dict.get("status_details", [])

if hits:
# log all statuses found and update rollout status from the last hit
for hit in hits:
document = hit["_source"]
logger.info(
f"Found log for rollout {row.execution_metadata.rollout_id} with status code {document['status_code']}"
f"Found Fireworks log for rollout {row.execution_metadata.rollout_id} with status code {status_code}"
)
# Update rollout status from the document
if "status_code" in document:
row.rollout_status = Status(
code=Status.Code(document["status_code"]),
message=document.get("status_message", ""),
details=document.get("status_details", []),
)
logger.info("Stopping status polling for rollout %s", row.execution_metadata.rollout_id)
break

row.rollout_status = Status(
code=Status.Code(status_code),
message=status_message,
details=status_details,
)

logger.info("Stopping polling for rollout %s", row.execution_metadata.rollout_id)
break

await asyncio.sleep(poll_interval)
else:
Expand Down
2 changes: 0 additions & 2 deletions eval_protocol/pytest/tracing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def build_init_request(
row: EvaluationRow,
config: RolloutProcessorConfig,
model_base_url: str,
elastic_search_config: Optional[Any] = None,
) -> InitRequest:
"""Build an InitRequest from an EvaluationRow and config (shared logic)."""
# Validation
Expand Down Expand Up @@ -129,7 +128,6 @@ def build_init_request(
tools=row.tools,
metadata=meta,
model_base_url=final_model_base_url,
elastic_search_config=elastic_search_config,
)


Expand Down
13 changes: 5 additions & 8 deletions tests/remote_server/remote_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,22 @@
from openai import OpenAI
import logging

from eval_protocol import Status, InitRequest, ElasticsearchDirectHttpHandler, RolloutIdFilter
from eval_protocol import Status, InitRequest, FireworksTracingHttpHandler, RolloutIdFilter


app = FastAPI()

# attach handler to root logger
handler = ElasticsearchDirectHttpHandler()
logging.getLogger().addHandler(handler)
# Attach Fireworks tracing handler to root logger
fireworks_handler = FireworksTracingHttpHandler()
logging.getLogger().addHandler(fireworks_handler)


force_early_error_message = None


@app.post("/init")
def init(req: InitRequest):
if req.elastic_search_config:
handler.configure(req.elastic_search_config)

# attach rollout_id filter to logger
# Attach rollout_id filter to logger
logger = logging.getLogger(f"{__name__}.{req.metadata.rollout_id}")
logger.addFilter(RolloutIdFilter(req.metadata.rollout_id))

Expand Down
13 changes: 5 additions & 8 deletions tests/remote_server/remote_server_multi_turn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,19 @@
from openai import OpenAI
import logging

from eval_protocol import Status, InitRequest, ElasticsearchDirectHttpHandler, RolloutIdFilter
from eval_protocol import Status, InitRequest, FireworksTracingHttpHandler, RolloutIdFilter


app = FastAPI()

# attach handler to root logger
handler = ElasticsearchDirectHttpHandler()
logging.getLogger().addHandler(handler)
# Attach Fireworks tracing handler to root logger
fireworks_handler = FireworksTracingHttpHandler()
logging.getLogger().addHandler(fireworks_handler)


@app.post("/init")
def init(req: InitRequest):
if req.elastic_search_config:
handler.configure(req.elastic_search_config)

# attach rollout_id filter to logger
# Attach rollout_id filter to logger
logger = logging.getLogger(f"{__name__}.{req.metadata.rollout_id}")
logger.addFilter(RolloutIdFilter(req.metadata.rollout_id))

Expand Down
Loading