@@ -126,7 +126,7 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
126126 server_script_path = ep_config .get ("server_script_path" )
127127 steps = ep_config .get ("steps" )
128128 mode = ep_config .get ("mode" )
129- combine_datasets = ep_config . get ( "combine_datasets" )
129+ # combine_datasets captured but not used here
130130
131131 # Choose the first rollout param set by default
132132 rollout_params = None
@@ -169,3 +169,162 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
169169 return test_wrapper
170170
171171 return _decorator
172+
173+
def _composite_out_is_json_file(out: str) -> bool:
    """Return True when *out* names a ``.json`` file rather than a directory."""
    import pathlib

    # A trailing slash forces directory semantics even for names like "x.json/",
    # because pathlib drops the trailing slash before computing the suffix.
    return pathlib.Path(out).suffix.lower() == ".json" and not str(out).endswith("/")


def _composite_row_count(summary: Dict[str, Any]) -> int:
    """Return the child's ``rows`` value as an int, or 0 when missing or not numeric."""
    try:
        return int(summary.get("rows", 0) or 0)
    except (TypeError, ValueError):
        # A malformed child summary must not abort the whole composite run.
        return 0


def register_composite_benchmark(name: str, children: List[str]) -> None:
    """
    Register a composite benchmark that runs multiple exported benchmarks and aggregates results.

    The composite runner forwards common overrides to each child benchmark and aggregates
    a combined score as a rows-weighted mean of each child's aggregated score. Per-metric
    means found under a child's ``metrics_agg`` are combined the same way. When children
    return a ``results`` list, a confidence interval is recomputed over all collected rows
    on a best-effort basis.

    Args:
        name: Name of the composite benchmark to register.
        children: List of child benchmark names previously registered via export_benchmark.

    The registered runner accepts keyword-only overrides (``model``, ``print_summary``,
    ``out``, ``reasoning_effort``, ``max_rows``, ``num_runs``, ``input_params_override``,
    ``max_concurrency``) and returns ``{"summary": combined}``. If ``out`` ends in ``.json``
    (and has no trailing slash) the composite summary is written to that file and children
    receive its parent directory; otherwise ``out`` is treated as a directory and the
    summary is written to ``<name>__composite.json`` inside it.
    """
    # Function-scope imports keep this block independent of the module's import list.
    import pathlib
    import time

    def _composite_runner(
        *,
        model: Optional[str] = None,
        print_summary: bool = False,
        out: Optional[str] = None,
        reasoning_effort: Optional[str] = None,
        max_rows: Optional[int | str] = None,
        num_runs: Optional[int] = None,
        input_params_override: Optional[Dict[str, Any]] = None,
        max_concurrency: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Run every child benchmark and return the aggregated composite summary."""
        child_summaries: List[Dict[str, Any]] = []
        total_rows = 0
        weighted_sum = 0.0
        # Per-metric accumulators: weighted sum of means and the matching row totals.
        metric_weighted_sums: Dict[str, float] = {}
        metric_total_rows: Dict[str, int] = {}
        combined_rows: List[Any] = []

        # Decide once whether 'out' is a file or a directory; both the child
        # artifact directory and the composite output path depend on it.
        out_is_file = bool(out) and _composite_out_is_json_file(out)
        child_out_dir: Optional[str] = None
        if out:
            child_out_dir = str(pathlib.Path(out).parent) if out_is_file else out

        for child_name in children:
            # Resolved at call time rather than at registration time.
            runner = get_benchmark_runner(child_name)
            result = runner(
                model=model,
                print_summary=print_summary,
                out=child_out_dir,
                reasoning_effort=reasoning_effort,
                max_rows=max_rows,
                num_runs=num_runs,
                input_params_override=input_params_override,
                max_concurrency=max_concurrency,
            )
            summary = result.get("summary") if isinstance(result, dict) else None
            # Skip children that produced no summary or a summary of an unexpected type.
            if not summary or not isinstance(summary, dict):
                continue

            # Gather underlying rows so the CI can be recomputed across children.
            rows_obj = result.get("results")
            if isinstance(rows_obj, list):
                combined_rows.extend(rows_obj)

            child_summaries.append(summary)
            rows = _composite_row_count(summary)
            agg = summary.get("agg_score")
            if isinstance(agg, (int, float)) and rows > 0:
                total_rows += rows
                weighted_sum += float(agg) * rows

            # Combine per-metric means if available; rows == 0 contributes nothing.
            metrics_agg = summary.get("metrics_agg") or {}
            if isinstance(metrics_agg, dict) and rows > 0:
                for m_name, m_vals in metrics_agg.items():
                    if not isinstance(m_vals, dict):
                        continue  # malformed metric entry; ignore instead of crashing
                    m_mean = m_vals.get("mean")
                    if isinstance(m_mean, (int, float)):
                        metric_weighted_sums[m_name] = metric_weighted_sums.get(m_name, 0.0) + float(m_mean) * rows
                        metric_total_rows[m_name] = metric_total_rows.get(m_name, 0) + rows

        combined_agg = (weighted_sum / total_rows) if total_rows > 0 else None

        # Compute a CI over the combined rows if available. This is deliberately
        # best-effort: a missing stats module or a failing computation only
        # omits the CI fields from the summary.
        ci_low: Optional[float] = None
        ci_high: Optional[float] = None
        if combined_rows:
            try:
                from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci as _compute_ci

                r = _compute_ci(combined_rows)
                if r and len(r) >= 3 and r[1] is not None and r[2] is not None:
                    ci_low = float(r[1])
                    ci_high = float(r[2])
            except Exception:
                ci_low = None
                ci_high = None
        has_ci = ci_low is not None and ci_high is not None

        combined_metrics: Dict[str, Dict[str, float]] = {
            m_name: {"mean": float(wsum / metric_total_rows[m_name])}
            for m_name, wsum in metric_weighted_sums.items()
            if metric_total_rows.get(m_name, 0) > 0
        }

        combined: Dict[str, Any] = {
            "suite": name,
            "model": model,
            "agg_score": float(combined_agg) if combined_agg is not None else None,
            "rows": total_rows,
            "children": child_summaries,
            "num_runs": num_runs,
        }
        if combined_metrics:
            combined["metrics_agg"] = combined_metrics
        if has_ci:
            combined["agg_ci_low"] = ci_low
            combined["agg_ci_high"] = ci_high

        # Optional print: respect either the function arg or the EP_PRINT_SUMMARY env var.
        if print_summary or (os.getenv("EP_PRINT_SUMMARY") == "1"):
            if combined_agg is not None:
                agg_text = f"{combined['agg_score']:.3f}"
                # The CI is only reported alongside a numeric aggregate score.
                ci_text = f" ci95=[{ci_low:.3f},{ci_high:.3f}]" if has_ci else ""
            else:
                agg_text = "None"
                ci_text = ""
            try:
                print(f"EP Summary | suite={name} model={model} agg={agg_text}{ci_text} rows={total_rows}")
            except Exception:
                # Deliberate best-effort: a broken stdout must not fail the benchmark run.
                pass

        # Optional persist: write the composite summary with a Unix timestamp.
        if out:
            if out_is_file:
                target = pathlib.Path(out)
            else:
                safe_name = name.replace("/", "__")
                target = pathlib.Path(out) / f"{safe_name}__composite.json"
            target.parent.mkdir(parents=True, exist_ok=True)
            with open(target, "w", encoding="utf-8") as f:
                json.dump({**combined, "timestamp": int(time.time())}, f)

        return {"summary": combined}

    # Register (overwrite if exists)
    _BENCHMARK_REGISTRY[name] = _composite_runner
0 commit comments