88
99import math
1010import re
11- from typing import Any , Dict , List , Optional , Set , Tuple , Union
11+ from typing import Any , Dict , List , Optional , Sequence , Set , Tuple , Union , cast
1212
1313from ..models import EvaluateResult , Message , MetricResult
1414from ..typed_interface import reward_function
1515
16+ # Types used throughout this module to clearly express allowed answer values.
17+ # Include both float and int since extraction may yield either at analysis time.
18+ Numeric = Union [int , float ]
19+ AnswerValue = Union [Numeric , str ]
20+
1621_ALGEBRAIC_VARS_SET : Set [str ] = {
1722 "x" ,
1823 "y" ,
@@ -78,9 +83,9 @@ def _is_coefficient(
7883 return False
7984
8085
81- def _extract_html_tag_answers (text : str ) -> List [Tuple [str , Union [ float , str ] ]]:
86+ def _extract_html_tag_answers (text : str ) -> List [Tuple [str , AnswerValue ]]:
8287 """Extracts answers from <answer> or <ans> HTML-like tags."""
83- html_tag_answers : List [Tuple [str , Union [ float , str ] ]] = []
88+ html_tag_answers : List [Tuple [str , AnswerValue ]] = []
8489 tag_re = re .compile (
8590 r"<(?P<tag>answer|ans)\b[^>]*>(?P<inner>.*?)</(?P=tag)>" ,
8691 re .IGNORECASE | re .DOTALL ,
@@ -126,12 +131,12 @@ def _extract_html_tag_answers(text: str) -> List[Tuple[str, Union[float, str]]]:
126131
127132def _extract_boxed_latex_answers (
128133 text : str ,
129- ) -> Tuple [List [Tuple [str , Union [ float , str ] ]], bool ]:
134+ ) -> Tuple [List [Tuple [str , AnswerValue ]], bool ]:
130135 """
131136 Extracts answers from \\ boxed{} LaTeX expressions.
132137 Returns a tuple: (list of answers, boolean indicating if any boxed expr was found).
133138 """
134- boxed_answers : List [Tuple [str , Union [ float , str ] ]] = []
139+ boxed_answers : List [Tuple [str , AnswerValue ]] = []
135140 found_any_boxed_expr = False
136141 for m_boxed in re .finditer (r"\\boxed\s*\{\s*((?:[^{}]|\{[^{}]*\})*?)\s*\}" , text ):
137142 found_any_boxed_expr = True
@@ -192,7 +197,7 @@ def _extract_boxed_latex_answers(
192197 return boxed_answers , found_any_boxed_expr
193198
194199
195- def extract_numbers (text : str ) -> List [Tuple [str , Union [ float , str ] ]]:
200+ def extract_numbers (text : str ) -> List [Tuple [str , AnswerValue ]]:
196201 """
197202 Extracts mathematical answers from text based on a hierarchical priority:
198203 1. HTML <answer>/<ans> tags
@@ -228,7 +233,7 @@ def extract_numbers(text: str) -> List[Tuple[str, Union[float, str]]]:
228233 return []
229234
230235
231- def _extract_gsm8k_answers (text : str ) -> List [Tuple [str , Union [ float , str ] ]]:
236+ def _extract_gsm8k_answers (text : str ) -> List [Tuple [str , AnswerValue ]]:
232237 """Extracts answers from GSM8K-style final answer markers (#### ...)."""
233238 final_marker_answers : List [Tuple [str , Union [float , str ]]] = []
234239 GSM8K_NUM_CONTENT_PATTERN = r"-?\d{1,3}(?:,\d{3})*(?:\.\d+)?|-?\d+(?:\.\d+)?"
@@ -243,7 +248,7 @@ def _extract_gsm8k_answers(text: str) -> List[Tuple[str, Union[float, str]]]:
243248 return final_marker_answers
244249
245250
246- def _extract_general_numeric_answers (text : str ) -> List [Tuple [str , Union [ float , str ] ]]:
251+ def _extract_general_numeric_answers (text : str ) -> List [Tuple [str , AnswerValue ]]:
247252 """Extracts general numeric or LaTeX-formatted numbers as a fallback."""
248253 potential_general_matches : List [Dict [str , Any ]] = []
249254
@@ -399,7 +404,7 @@ def _extract_general_numeric_answers(text: str) -> List[Tuple[str, Union[float,
399404 pass
400405
401406 potential_general_matches .sort (key = lambda x : (x ["span" ][0 ], - (x ["span" ][1 ] - x ["span" ][0 ]), x ["type_priority" ]))
402- filtered_general_answers : List [Tuple [str , Union [ float , str ] ]] = []
407+ filtered_general_answers : List [Tuple [str , AnswerValue ]] = []
403408 last_covered_end = - 1
404409 for item in potential_general_matches :
405410 start , end = item ["span" ]
@@ -461,7 +466,7 @@ def _has_unit_text(full_extracted_text: str, numeric_value: float) -> bool:
461466
462467def _check_unboxed_or_strictness (
463468 model_response_content : str ,
464- gen_answers_extracted : List [Tuple [str , Union [ float , str ] ]],
469+ gen_answers_extracted : Sequence [Tuple [str , AnswerValue ]],
465470 metrics : Dict [str , MetricResult ],
466471) -> Optional [EvaluateResult ]:
467472 """Checks for 'unboxed or' strictness violation."""
@@ -487,8 +492,8 @@ def _check_unboxed_or_strictness(
487492
488493
489494def _check_ambiguity_strictness (
490- orig_answers_extracted : List [Tuple [str , Union [ float , str ] ]],
491- gen_answers_extracted : List [Tuple [str , Union [ float , str ] ]],
495+ orig_answers_extracted : Sequence [Tuple [str , AnswerValue ]],
496+ gen_answers_extracted : Sequence [Tuple [str , AnswerValue ]],
492497 metrics : Dict [str , MetricResult ],
493498) -> Optional [EvaluateResult ]:
494499 """Checks for ambiguity strictness violation."""
@@ -503,8 +508,8 @@ def _check_ambiguity_strictness(
503508
504509
505510def _check_conflicting_answers_strictness (
506- orig_answers_extracted : List [Tuple [str , Union [ float , str ] ]],
507- gen_answers_extracted : List [Tuple [str , Union [ float , str ] ]],
511+ orig_answers_extracted : Sequence [Tuple [str , AnswerValue ]],
512+ gen_answers_extracted : Sequence [Tuple [str , AnswerValue ]],
508513 best_match_score : float ,
509514 match_found_flag : bool ,
510515 is_single_orig_boxed_truth : bool ,
@@ -603,7 +608,7 @@ def math_reward(
603608
604609 gen_answers_extracted_initial = extract_numbers (model_response_content )
605610 orig_answers_extracted = extract_numbers (ground_truth )
606- gen_answers_extracted = list (gen_answers_extracted_initial )
611+ gen_answers_extracted : List [ Tuple [ str , AnswerValue ]] = list (gen_answers_extracted_initial )
607612 metrics : Dict [str , MetricResult ] = {}
608613
609614 def format_extracted (items : List [Tuple [str , Union [float , str ]]]) -> str :
@@ -654,7 +659,7 @@ def format_extracted(items: List[Tuple[str, Union[float, str]]]) -> str:
654659 abs_tol = absolute_tolerance ,
655660 ):
656661 has_matching_gen_boxed_answer = True
657- gen_answers_extracted = [(gen_text , gen_val )]
662+ gen_answers_extracted = [(gen_text , cast ( AnswerValue , gen_val ) )]
658663 metrics ["demo_leniency_info" ] = MetricResult (
659664 score = 1.0 ,
660665 is_score_valid = True ,
0 commit comments