Skip to content

Commit 9d66e5f

Browse files
author
Dylan Huang
committed
Add Elasticsearch setup and configuration to RemoteRolloutProcessor
- Introduced setup method in RemoteRolloutProcessor to initialize Elasticsearch if not disabled.
- Added ElasticSearchConfig model for managing Elasticsearch configuration.
- Implemented logic to parse environment variables from a .env file and start Elasticsearch if necessary.
- Updated evaluation_test to call rollout_processor.setup() for proper initialization.
- Modified RolloutProcessor to include a setup method for potential overrides in subclasses.
1 parent b99fcc5 commit 9d66e5f

6 files changed

Lines changed: 197 additions & 3 deletions

File tree

.gitignore

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,5 +239,3 @@ package-lock.json
239239
package.json
240240
tau2-bench
241241
*.err
242-
243-
elastic-start-local/

eval_protocol/pytest/evaluation_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,8 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
367367
exception_handler_config=exception_handler_config,
368368
)
369369

370+
rollout_processor.setup()
371+
370372
async def execute_run(run_idx: int, config: RolloutProcessorConfig):
371373
nonlocal all_results
372374

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
11
import asyncio
2+
import subprocess
23
import time
34
from typing import Any, Dict, List, Optional, Callable
45

6+
from dotenv import load_dotenv
57
import requests
68

9+
from eval_protocol.directory_utils import find_eval_protocol_dir
710
from eval_protocol.models import EvaluationRow, Status
811
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
9-
from eval_protocol.types.remote_rollout_processor import InitRequest, RolloutMetadata
12+
from eval_protocol.types.remote_rollout_processor import ElasticSearchConfig, InitRequest, RolloutMetadata
1013
from .rollout_processor import RolloutProcessor
1114
from .types import RolloutProcessorConfig
15+
import logging
16+
1217
import os
1318

19+
logger = logging.getLogger(__name__)
20+
1421

1522
class RemoteRolloutProcessor(RolloutProcessor):
1623
"""
@@ -27,6 +34,8 @@ def __init__(
2734
poll_interval: float = 1.0,
2835
timeout_seconds: float = 120.0,
2936
output_data_loader: Callable[[str], DynamicDataLoader],
37+
disable_elastic_search: bool = False,
38+
elastic_search_config: Optional[ElasticSearchConfig] = None,
3039
):
3140
# Prefer constructor-provided configuration. These can be overridden via
3241
# config.kwargs at call time for backward compatibility.
@@ -37,6 +46,58 @@ def __init__(
3746
self._poll_interval = poll_interval
3847
self._timeout_seconds = timeout_seconds
3948
self._output_data_loader = output_data_loader
49+
self._disable_elastic_search = disable_elastic_search
50+
self._elastic_search_config = elastic_search_config
51+
52+
def setup(self) -> None:
    """Prepare the Elasticsearch configuration used by this processor.

    Behaviour:
        * If Elasticsearch was disabled in the constructor, nothing happens.
        * If a configuration was passed to the constructor, it is kept
          untouched; no local instance is started.
        * Otherwise a local Elasticsearch instance is started and the
          resulting connection details are stored on the processor.

    Raises:
        RuntimeError: Propagated from ``_setup_elastic_search`` when the local
            instance cannot be started or its .env file cannot be parsed.
    """
    if self._disable_elastic_search:
        logger.info("Elasticsearch is disabled, skipping setup")
        return
    if self._elastic_search_config is not None:
        # A caller-supplied configuration must not be silently replaced by a
        # locally started instance; it also makes repeated setup() calls cheap.
        logger.info("Elasticsearch config already provided, skipping local setup")
        return
    logger.info("Setting up Elasticsearch")
    self._elastic_search_config = self._setup_elastic_search()
    logger.info("Elasticsearch setup complete")
59+
60+
def _parse_elastic_env_file(self, env_file_path: str) -> ElasticSearchConfig:
    """Parse ES_LOCAL_API_KEY and ES_LOCAL_URL from .env file.

    The file is loaded into the process environment and the two variables are
    then read back from it.

    Args:
        env_file_path: Path of the .env file to load.

    Returns:
        An ``ElasticSearchConfig`` built from ``ES_LOCAL_URL`` and
        ``ES_LOCAL_API_KEY``.

    Raises:
        RuntimeError: If the file could not be loaded, or if either variable
            is missing or empty after loading.
    """
    # override=True so that the values in the file win over any stale
    # ES_LOCAL_* variables that are already present in the environment;
    # the default (override=False) would silently keep the old values.
    loaded = load_dotenv(env_file_path, override=True)
    if not loaded:
        raise RuntimeError("Failed to load .env file")
    api_key = os.getenv("ES_LOCAL_API_KEY")
    url = os.getenv("ES_LOCAL_URL")
    if not url or not api_key:
        raise RuntimeError("Failed to parse ES_LOCAL_API_KEY and ES_LOCAL_URL from .env file")
    return ElasticSearchConfig(url=url, api_key=api_key)
70+
71+
def _setup_elastic_search(self) -> ElasticSearchConfig:
    """Start a local Elasticsearch instance and return its configuration.

    If an ``elastic-start-local`` directory already exists under the
    eval-protocol directory, its ``start.sh`` is executed. Otherwise the
    Elasticsearch start-local installer is downloaded and run from the
    eval-protocol directory. In both cases the configuration is read from
    ``elastic-start-local/.env`` afterwards.

    Returns:
        The ``ElasticSearchConfig`` parsed from the .env file.

    Raises:
        RuntimeError: If the start script or the installer exits with a
            non-zero status, or if the .env file cannot be parsed.
    """
    eval_protocol_dir = find_eval_protocol_dir()
    elastic_start_local_dir = os.path.join(eval_protocol_dir, "elastic-start-local")
    env_file_path = os.path.join(elastic_start_local_dir, ".env")

    # An existing elastic-start-local directory means a previous install is
    # available: just (re)start it via its start.sh script.
    if os.path.exists(elastic_start_local_dir):
        from eval_protocol.utils.subprocess_utils import run_script_and_wait

        run_script_and_wait(
            script_name="start.sh",
            working_directory=elastic_start_local_dir,
            inherit_stdout=True,
        )
        return self._parse_elastic_env_file(env_file_path)

    # Run the Elasticsearch start-local script:
    #   "curl -fsSL https://elastic.co/start-local | sh -s -- --esonly"
    # subprocess.run drains the pipe while waiting; Popen(stdout=PIPE) followed
    # by wait() can deadlock once the child fills the OS pipe buffer.
    completed = subprocess.run(
        ["sh", "-c", "curl -fsSL https://elastic.co/start-local | sh -s -- --esonly"],
        cwd=eval_protocol_dir,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    if completed.returncode != 0:
        # Surface the installer output so the failure can be diagnosed.
        logger.error("Elasticsearch start-local output:\n%s", completed.stdout)
        raise RuntimeError("Failed to start Elasticsearch")

    return self._parse_elastic_env_file(env_file_path)
40101

41102
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
42103
tasks: List[asyncio.Task[EvaluationRow]] = []
@@ -119,6 +180,7 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
119180
tools=row.tools,
120181
metadata=meta,
121182
model_base_url=model_base_url,
183+
elastic_search_config=self._elastic_search_config,
122184
)
123185

124186
# Fire-and-poll

eval_protocol/pytest/rollout_processor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ class RolloutProcessor(ABC):
1010
Abstract base class for all rollout processor strategies.
1111
"""
1212

13+
def setup(self) -> None:
14+
"""Setup resources. Override in subclasses if setup is needed. Executed once per invocation."""
15+
pass
16+
1317
@abstractmethod
1418
def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
1519
"""Process evaluation rows and return async tasks. Must be implemented by subclasses."""

eval_protocol/types/remote_rollout_processor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@
77
from eval_protocol.models import Message, Status
88

99

10+
class ElasticSearchConfig(BaseModel):
    """Connection settings for an Elasticsearch instance."""

    # Address of the Elasticsearch endpoint.
    url: str
    # API key sent along with the URL (presumably used for authentication —
    # confirm against the consumer of this model).
    api_key: str
17+
18+
1019
class RolloutMetadata(BaseModel):
1120
"""Metadata for rollout execution."""
1221

@@ -21,6 +30,7 @@ class InitRequest(BaseModel):
2130
"""Request model for POST /init endpoint."""
2231

2332
model: str
33+
elastic_search_config: Optional[ElasticSearchConfig] = None
2434
messages: Optional[List[Message]] = None
2535
tools: Optional[List[Dict[str, Any]]] = None
2636

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""Cross-platform subprocess utilities for running scripts and commands."""
2+
3+
import os
4+
import platform
5+
import subprocess
6+
from typing import Optional
7+
8+
9+
def run_script_cross_platform(
    script_name: str,
    working_directory: str,
    capture_output: bool = True,
    print_output: bool = False,
    inherit_stdout: bool = False,
) -> subprocess.Popen:
    """
    Run a script in a cross-platform manner.

    Args:
        script_name: Name of the script to run (e.g., "start.sh")
        working_directory: Directory containing the script; the script is
            also started with this directory as its working directory
        capture_output: Whether to capture stdout/stderr through a pipe
        print_output: Whether to echo captured output line by line; when
            enabled together with ``capture_output`` this function only
            returns after the script has closed its output
        inherit_stdout: Whether to inherit stdout from parent process
            (stderr is merged into it); takes precedence over
            ``capture_output``

    Returns:
        subprocess.Popen object for the running process

    Raises:
        FileNotFoundError: If the script does not exist in
            ``working_directory``
        OSError: If the script cannot be made executable or started
    """
    import shlex  # local import: only needed for quoting on the Unix branch

    script_path = os.path.join(working_directory, script_name)

    if not os.path.exists(script_path):
        raise FileNotFoundError(f"Script not found: {script_path}")

    # Determine stdout/stderr handling.
    if inherit_stdout:
        stdout = None  # Inherit from parent process
        stderr = subprocess.STDOUT  # Merge stderr into the inherited stdout
    elif capture_output:
        stdout = subprocess.PIPE
        stderr = subprocess.STDOUT
    else:
        stdout = None
        stderr = None

    if platform.system() == "Windows":
        # On Windows, use cmd.exe to run the script from its directory.
        process = subprocess.Popen(
            ["cmd.exe", "/c", script_name],
            cwd=working_directory,
            stdout=stdout,
            stderr=stderr,
            text=True,
        )
    else:
        # Resolve to an absolute path first: the child runs with
        # cwd=working_directory, so a relative path would no longer resolve.
        absolute_script_path = os.path.abspath(script_path)

        # Only touch permissions when needed; chmod fails for files we can
        # execute but do not own.
        if not os.access(absolute_script_path, os.X_OK):
            os.chmod(absolute_script_path, 0o755)

        # Run through the shell (keeps shebang handling and the shell's
        # fallback for shebang-less scripts), quoting the path so spaces or
        # shell metacharacters in it cannot break or alter the command.
        process = subprocess.Popen(
            shlex.quote(absolute_script_path),
            cwd=working_directory,
            stdout=stdout,
            stderr=stderr,
            text=True,
            shell=True,
        )

    # Print output in real-time if requested (only possible when piped).
    if print_output and capture_output and process.stdout:
        for line in process.stdout:
            print(line, end="")

    return process
77+
78+
79+
def run_script_and_wait(
    script_name: str,
    working_directory: str,
    print_output: bool = False,
    inherit_stdout: bool = False,
    timeout: Optional[int] = None,
) -> int:
    """
    Run a script and wait for it to complete.

    Args:
        script_name: Name of the script to run
        working_directory: Directory to run the script in
        print_output: Whether to print output in real-time
        inherit_stdout: Whether to inherit stdout from parent process
        timeout: Maximum time in seconds to wait for the script to complete
            (``None`` waits indefinitely)

    Returns:
        Return code of the script (always 0, since a non-zero code raises)

    Raises:
        FileNotFoundError: If the script does not exist
        RuntimeError: If the script exits with a non-zero return code
        subprocess.TimeoutExpired: If the script times out; the process is
            killed before the exception propagates
    """
    process = run_script_cross_platform(
        script_name=script_name,
        working_directory=working_directory,
        # Output is only piped when it is going to be printed; otherwise an
        # undrained pipe could block the child.
        capture_output=print_output and not inherit_stdout,
        print_output=print_output,
        inherit_stdout=inherit_stdout,
    )

    try:
        returncode = process.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        process.kill()
        # Reap the killed process so it does not linger as a zombie.
        process.wait()
        raise
    finally:
        # Release the pipe created when output was captured.
        if process.stdout is not None:
            process.stdout.close()

    if returncode != 0:
        raise RuntimeError(f"Script '{script_name}' failed with return code {returncode}")
    return returncode

0 commit comments

Comments
 (0)