Skip to content

Commit 63d0fb5

Browse files
Benny ChenBenny Chen
authored andcommitted
fix more errors
1 parent b57b87e commit 63d0fb5

File tree

6 files changed

+42
-21
lines changed

6 files changed

+42
-21
lines changed

eval_protocol/integrations/deepeval.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,11 @@ def _build_case_kwargs() -> Dict[str, Any]:
9696
case_kwargs = _build_case_kwargs()
9797
test_case = LLMTestCase(**case_kwargs)
9898

99-
metric.measure(test_case, **kwargs)
99+
# Guard against metric.measure being None or non-callable
100+
measure_fn = getattr(metric, "measure", None)
101+
if not callable(measure_fn):
102+
raise TypeError("Provided metric does not have a callable 'measure' method")
103+
measure_fn(test_case, **kwargs)
100104
score = float(metric.score or 0.0)
101105
reason = getattr(metric, "reason", None)
102106
name = _metric_name(metric)

eval_protocol/mcp_agent/orchestration/local_docker_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ async def startup(self) -> None:
5757
except docker.errors.DockerException as e:
5858
logger.warning(f"docker.from_env() failed: {e}. Trying explicit base_url.")
5959
try:
60+
# docker.from_env is preferred, but as a fallback use DockerClient with url param name 'base_url'
6061
self.docker_client = docker.DockerClient(base_url="unix://var/run/docker.sock")
6162
if not self.docker_client.ping(): # type: ignore
6263
raise ConnectionError("Failed to connect to Docker daemon with explicit base_url.")
@@ -649,7 +650,7 @@ async def list_tools_on_instance(self, instance: ManagedInstanceInfo) -> types.L
649650
)
650651
target_base_url = instance.mcp_endpoint_url.rstrip("/")
651652
try:
652-
async with streamablehttp_client(base_url=target_base_url) as (
653+
async with streamablehttp_client(url=target_base_url) as (
653654
read_s,
654655
write_s,
655656
_, # get_session_id_func usually not needed for a single call

eval_protocol/mcp_servers/tau2/tau2_mcp.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(self, seed: Optional[int] = None, **kwargs):
4343

4444
self.adapter = EnvironmentAdapter(env_class=AirlineEnvironment, default_config=default_config)
4545

46+
# Ensure name is a str and not None
4647
super().__init__("airline", self.adapter, seed, **kwargs)
4748

4849
def _register_tools(self):

eval_protocol/pytest/plugin.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,11 @@ def pytest_sessionfinish(session, exitstatus):
309309
print(f"❌ Experiment {link['experiment_id']}: {link['job_link']}", file=sys.__stderr__)
310310

311311
print("=" * 80, file=sys.__stderr__)
312-
sys.__stderr__.flush()
312+
err_stream = getattr(sys, "__stderr__", None)
313+
if err_stream is not None:
314+
try:
315+
err_stream.flush() # type: ignore[attr-defined]
316+
except Exception:
317+
pass
313318
except Exception:
314319
pass

eval_protocol/rewards/json_schema.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,13 @@ def json_schema_reward_with_llm_judge(
342342
if messages:
343343
conversation_parts = []
344344
for msg in messages[:-1]:
345-
role = msg.get("role", "")
346-
content_part = msg.get("content", "")
345+
if isinstance(msg, dict):
346+
role = msg.get("role", "")
347+
content_part = msg.get("content", "")
348+
else:
349+
# Fallback for Message objects
350+
role = getattr(msg, "role", "")
351+
content_part = getattr(msg, "content", "")
347352
if role and content_part:
348353
conversation_parts.append(f"{role}: {content_part}")
349354
if conversation_parts:

eval_protocol/rewards/math.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,16 @@
88

99
import math
1010
import re
11-
from typing import Any, Dict, List, Optional, Set, Tuple, Union
11+
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
1212

1313
from ..models import EvaluateResult, Message, MetricResult
1414
from ..typed_interface import reward_function
1515

16+
# Types used throughout this module to clearly express allowed answer values.
17+
# Include both float and int since extraction may yield either at analysis time.
18+
Numeric = Union[int, float]
19+
AnswerValue = Union[Numeric, str]
20+
1621
_ALGEBRAIC_VARS_SET: Set[str] = {
1722
"x",
1823
"y",
@@ -78,9 +83,9 @@ def _is_coefficient(
7883
return False
7984

8085

81-
def _extract_html_tag_answers(text: str) -> List[Tuple[str, Union[float, str]]]:
86+
def _extract_html_tag_answers(text: str) -> List[Tuple[str, AnswerValue]]:
8287
"""Extracts answers from <answer> or <ans> HTML-like tags."""
83-
html_tag_answers: List[Tuple[str, Union[float, str]]] = []
88+
html_tag_answers: List[Tuple[str, AnswerValue]] = []
8489
tag_re = re.compile(
8590
r"<(?P<tag>answer|ans)\b[^>]*>(?P<inner>.*?)</(?P=tag)>",
8691
re.IGNORECASE | re.DOTALL,
@@ -126,12 +131,12 @@ def _extract_html_tag_answers(text: str) -> List[Tuple[str, Union[float, str]]]:
126131

127132
def _extract_boxed_latex_answers(
128133
text: str,
129-
) -> Tuple[List[Tuple[str, Union[float, str]]], bool]:
134+
) -> Tuple[List[Tuple[str, AnswerValue]], bool]:
130135
"""
131136
Extracts answers from \\boxed{} LaTeX expressions.
132137
Returns a tuple: (list of answers, boolean indicating if any boxed expr was found).
133138
"""
134-
boxed_answers: List[Tuple[str, Union[float, str]]] = []
139+
boxed_answers: List[Tuple[str, AnswerValue]] = []
135140
found_any_boxed_expr = False
136141
for m_boxed in re.finditer(r"\\boxed\s*\{\s*((?:[^{}]|\{[^{}]*\})*?)\s*\}", text):
137142
found_any_boxed_expr = True
@@ -192,7 +197,7 @@ def _extract_boxed_latex_answers(
192197
return boxed_answers, found_any_boxed_expr
193198

194199

195-
def extract_numbers(text: str) -> List[Tuple[str, Union[float, str]]]:
200+
def extract_numbers(text: str) -> List[Tuple[str, AnswerValue]]:
196201
"""
197202
Extracts mathematical answers from text based on a hierarchical priority:
198203
1. HTML <answer>/<ans> tags
@@ -228,7 +233,7 @@ def extract_numbers(text: str) -> List[Tuple[str, Union[float, str]]]:
228233
return []
229234

230235

231-
def _extract_gsm8k_answers(text: str) -> List[Tuple[str, Union[float, str]]]:
236+
def _extract_gsm8k_answers(text: str) -> List[Tuple[str, AnswerValue]]:
232237
"""Extracts answers from GSM8K-style final answer markers (#### ...)."""
233238
final_marker_answers: List[Tuple[str, Union[float, str]]] = []
234239
GSM8K_NUM_CONTENT_PATTERN = r"-?\d{1,3}(?:,\d{3})*(?:\.\d+)?|-?\d+(?:\.\d+)?"
@@ -243,7 +248,7 @@ def _extract_gsm8k_answers(text: str) -> List[Tuple[str, Union[float, str]]]:
243248
return final_marker_answers
244249

245250

246-
def _extract_general_numeric_answers(text: str) -> List[Tuple[str, Union[float, str]]]:
251+
def _extract_general_numeric_answers(text: str) -> List[Tuple[str, AnswerValue]]:
247252
"""Extracts general numeric or LaTeX-formatted numbers as a fallback."""
248253
potential_general_matches: List[Dict[str, Any]] = []
249254

@@ -399,7 +404,7 @@ def _extract_general_numeric_answers(text: str) -> List[Tuple[str, Union[float,
399404
pass
400405

401406
potential_general_matches.sort(key=lambda x: (x["span"][0], -(x["span"][1] - x["span"][0]), x["type_priority"]))
402-
filtered_general_answers: List[Tuple[str, Union[float, str]]] = []
407+
filtered_general_answers: List[Tuple[str, AnswerValue]] = []
403408
last_covered_end = -1
404409
for item in potential_general_matches:
405410
start, end = item["span"]
@@ -461,7 +466,7 @@ def _has_unit_text(full_extracted_text: str, numeric_value: float) -> bool:
461466

462467
def _check_unboxed_or_strictness(
463468
model_response_content: str,
464-
gen_answers_extracted: List[Tuple[str, Union[float, str]]],
469+
gen_answers_extracted: Sequence[Tuple[str, AnswerValue]],
465470
metrics: Dict[str, MetricResult],
466471
) -> Optional[EvaluateResult]:
467472
"""Checks for 'unboxed or' strictness violation."""
@@ -487,8 +492,8 @@ def _check_unboxed_or_strictness(
487492

488493

489494
def _check_ambiguity_strictness(
490-
orig_answers_extracted: List[Tuple[str, Union[float, str]]],
491-
gen_answers_extracted: List[Tuple[str, Union[float, str]]],
495+
orig_answers_extracted: Sequence[Tuple[str, AnswerValue]],
496+
gen_answers_extracted: Sequence[Tuple[str, AnswerValue]],
492497
metrics: Dict[str, MetricResult],
493498
) -> Optional[EvaluateResult]:
494499
"""Checks for ambiguity strictness violation."""
@@ -503,8 +508,8 @@ def _check_ambiguity_strictness(
503508

504509

505510
def _check_conflicting_answers_strictness(
506-
orig_answers_extracted: List[Tuple[str, Union[float, str]]],
507-
gen_answers_extracted: List[Tuple[str, Union[float, str]]],
511+
orig_answers_extracted: Sequence[Tuple[str, AnswerValue]],
512+
gen_answers_extracted: Sequence[Tuple[str, AnswerValue]],
508513
best_match_score: float,
509514
match_found_flag: bool,
510515
is_single_orig_boxed_truth: bool,
@@ -603,7 +608,7 @@ def math_reward(
603608

604609
gen_answers_extracted_initial = extract_numbers(model_response_content)
605610
orig_answers_extracted = extract_numbers(ground_truth)
606-
gen_answers_extracted = list(gen_answers_extracted_initial)
611+
gen_answers_extracted: List[Tuple[str, AnswerValue]] = list(gen_answers_extracted_initial)
607612
metrics: Dict[str, MetricResult] = {}
608613

609614
def format_extracted(items: List[Tuple[str, Union[float, str]]]) -> str:
@@ -654,7 +659,7 @@ def format_extracted(items: List[Tuple[str, Union[float, str]]]) -> str:
654659
abs_tol=absolute_tolerance,
655660
):
656661
has_matching_gen_boxed_answer = True
657-
gen_answers_extracted = [(gen_text, gen_val)]
662+
gen_answers_extracted = [(gen_text, cast(AnswerValue, gen_val))]
658663
metrics["demo_leniency_info"] = MetricResult(
659664
score=1.0,
660665
is_score_valid=True,

0 commit comments

Comments
 (0)