diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 5ce8a436..b43df2b5 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -7,6 +7,7 @@ from __future__ import annotations import logging import requests +import time from datetime import datetime from typing import Any, Dict, List, Optional, Protocol @@ -280,8 +281,9 @@ def get_evaluation_rows( from_timestamp: Optional[datetime] = None, to_timestamp: Optional[datetime] = None, include_tool_calls: bool = True, - sleep_between_gets: float = 2.5, - max_retries: int = 3, + backend_sleep_between_gets: float = 0.1, + backend_max_retries: int = 3, + proxy_max_retries: int = 3, span_name: Optional[str] = None, converter: Optional[TraceDictConverter] = None, ) -> List[EvaluationRow]: @@ -303,8 +305,9 @@ def get_evaluation_rows( from_timestamp: Explicit start time (ISO format) to_timestamp: Explicit end time (ISO format) include_tool_calls: Whether to include tool calling traces - sleep_between_gets: Sleep time between trace.get() calls (handled by proxy) - max_retries: Maximum retries for rate limit errors (handled by proxy) + backend_sleep_between_gets: Sleep time between backend trace fetches (passed to proxy) + backend_max_retries: Maximum retries for backend operations (passed to proxy) + proxy_max_retries: Maximum retries when proxy returns 404 (client-side retries with exponential backoff) span_name: If provided, extract messages from generations within this named span converter: Optional custom converter implementing TraceDictConverter protocol. If provided, this will be used instead of the default conversion logic. @@ -336,25 +339,60 @@ def get_evaluation_rows( "hours_back": hours_back, "from_timestamp": from_timestamp.isoformat() if from_timestamp else None, "to_timestamp": to_timestamp.isoformat() if to_timestamp else None, - "sleep_between_gets": sleep_between_gets, - "max_retries": max_retries, + "sleep_between_gets": backend_sleep_between_gets, + "max_retries": backend_max_retries, } # Remove None values params = {k: v for k, v in params.items() if v is not None} - # Make request to proxy + # Make request to proxy with retry logic if self.project_id: url = f"{self.base_url}/v1/project_id/{self.project_id}/traces" else: url = f"{self.base_url}/v1/traces" - try: - response = requests.get(url, params=params, timeout=self.timeout) - response.raise_for_status() - result = response.json() - except requests.exceptions.RequestException as e: - logger.error("Failed to fetch traces from proxy: %s", e) + # Retry loop for handling backend indexing delays (proxy returns 404) + result = None + for attempt in range(proxy_max_retries): + try: + response = requests.get(url, params=params, timeout=self.timeout) + response.raise_for_status() + result = response.json() + break # Success, exit retry loop + except requests.exceptions.HTTPError as e: + error_msg = str(e) + should_retry = False + + # Try to extract detail message from response + if e.response is not None: + try: + error_detail = e.response.json().get("detail", "") + error_msg = error_detail or e.response.text + + # Retry on 404 if it's due to incomplete/missing traces (backend still indexing) + if e.response.status_code == 404 and ( + "Incomplete traces" in error_detail or "No traces found" in error_detail + ): + should_retry = True + except Exception: + error_msg = e.response.text + + if should_retry and attempt < proxy_max_retries - 1: + sleep_time = 2 ** (attempt + 1) + logger.warning(error_msg) + time.sleep(sleep_time) + else: + # Final retry or non-retryable error + logger.error("Failed to fetch traces from proxy: %s", error_msg) + return eval_rows + except requests.exceptions.RequestException as e: + # Non-HTTP errors (network issues, timeouts, etc.) + logger.error("Failed to fetch traces from proxy: %s", str(e)) + return eval_rows + + if result is None: + logger.error("Failed to fetch traces after %d retries", proxy_max_retries) return eval_rows # Extract traces from response diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index 1d4b6553..c20b4e21 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -58,7 +58,7 @@ def _default_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: def fetch_traces() -> List[EvaluationRow]: base_url = config.model_base_url or "https://tracing.fireworks.ai" adapter = FireworksTracingAdapter(base_url=base_url) - return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) + return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], proxy_max_retries=5) return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation) @@ -188,7 +188,10 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow: raise ValueError("Rollout ID is required in RemoteRolloutProcessor") final_model_base_url = model_base_url - if model_base_url and model_base_url.startswith("https://tracing.fireworks.ai"): + if model_base_url and ( + model_base_url.startswith("https://tracing.fireworks.ai") + or model_base_url.startswith("http://localhost") + ): final_model_base_url = _build_fireworks_tracing_url(model_base_url, meta) init_payload: InitRequest = InitRequest( diff --git a/tests/remote_server/remote_server_multi_turn.py b/tests/remote_server/remote_server_multi_turn.py new file mode 100644 index 00000000..155a0a2a --- /dev/null +++ b/tests/remote_server/remote_server_multi_turn.py @@ -0,0 +1,104 @@ +import os +import random +import threading + +import uvicorn +from fastapi import FastAPI +from openai import OpenAI +import logging + +from eval_protocol import Status, InitRequest, ElasticsearchDirectHttpHandler, RolloutIdFilter + + +app = FastAPI() + +# attach handler to root logger +handler = ElasticsearchDirectHttpHandler() +logging.getLogger().addHandler(handler) + + +@app.post("/init") +def init(req: InitRequest): + if req.elastic_search_config: + handler.configure(req.elastic_search_config) + + # attach rollout_id filter to logger + logger = logging.getLogger(f"{__name__}.{req.metadata.rollout_id}") + logger.addFilter(RolloutIdFilter(req.metadata.rollout_id)) + + # Kick off worker thread that does a multi-turn chat (6 turns total) + def _worker(): + try: + if not req.messages: + raise ValueError("messages is required") + + client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY")) + + # Build up conversation over 6 turns (3 user messages + 3 assistant responses) + # Convert Message objects to dicts for OpenAI API + conversation_history = [{"role": m.role, "content": m.content} for m in req.messages] + + follow_up_questions = [ + "Tell me more about that.", + "What else can you share about this topic?", + ] + + # First completion (turns 1-2: initial user message + assistant response) + logger.info(f"Turn 1-2: Sending initial completion request to model {req.model}") + completion = client.chat.completions.create( + model=req.model, + messages=conversation_history, # type: ignore + ) + assistant_message = completion.choices[0].message + assistant_content = assistant_message.content or "" + conversation_history.append({"role": "assistant", "content": assistant_content}) + logger.info(f"Turn 2 response: {assistant_content[:100]}...") + + # Second completion (turns 3-4: follow-up user message + assistant response) + conversation_history.append({"role": "user", "content": follow_up_questions[0]}) + logger.info(f"Turn 3: User asks: {follow_up_questions[0]}") + completion = client.chat.completions.create( + model=req.model, + messages=conversation_history, # type: ignore + ) + assistant_message = completion.choices[0].message + assistant_content = assistant_message.content or "" + conversation_history.append({"role": "assistant", "content": assistant_content}) + logger.info(f"Turn 4 response: {assistant_content[:100]}...") + + # Third completion (turns 5-6: another follow-up user message + assistant response) + conversation_history.append({"role": "user", "content": follow_up_questions[1]}) + logger.info(f"Turn 5: User asks: {follow_up_questions[1]}") + completion = client.chat.completions.create( + model=req.model, + messages=conversation_history, # type: ignore + ) + assistant_message = completion.choices[0].message + assistant_content = assistant_message.content or "" + conversation_history.append({"role": "assistant", "content": assistant_content}) + logger.info(f"Turn 6 response: {assistant_content[:100]}...") + + logger.info(f"Completed 6-turn conversation with {len(conversation_history)} messages total") + + except Exception as e: + # Best-effort; mark as done even on error to unblock polling + print(f"❌ Error in rollout {req.metadata.rollout_id}: {e}") + pass + finally: + logger.info( + f"Rollout {req.metadata.rollout_id} completed", + extra={"status": Status.rollout_finished()}, + ) + + t = threading.Thread(target=_worker, daemon=True) + t.start() + + +def main(): + host = os.getenv("REMOTE_SERVER_HOST", "127.0.0.1") + port = int(os.getenv("REMOTE_SERVER_PORT", "3000")) + uvicorn.run(app, host=host, port=port) + + +if __name__ == "__main__": + main() diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py index f647fe61..a6a66ba6 100644 --- a/tests/remote_server/test_remote_fireworks.py +++ b/tests/remote_server/test_remote_fireworks.py @@ -43,7 +43,7 @@ def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]: base_url = config.model_base_url or "https://tracing.fireworks.ai" adapter = FireworksTracingAdapter(base_url=base_url) - return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) + return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], proxy_max_retries=5) def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: @@ -65,7 +65,7 @@ def rows() -> List[EvaluationRow]: ), rollout_processor=RemoteRolloutProcessor( remote_base_url="http://127.0.0.1:3000", - timeout_seconds=30, + timeout_seconds=180, output_data_loader=fireworks_output_data_loader, ), )