@@ -128,7 +128,7 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
128128 server_script_path = ep_config .get ("server_script_path" )
129129 steps = ep_config .get ("steps" )
130130 mode = ep_config .get ("mode" )
131- combine_datasets = ep_config . get ( "combine_datasets" )
131+ # combine_datasets is intentionally not read here because it is not used in this function
132132
133133 # Choose the first rollout param set by default
134134 rollout_params = None
@@ -172,3 +172,161 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
172172 return _decorator
173173
174174
def register_composite_benchmark(name: str, children: List[str]) -> None:
    """
    Register a composite benchmark that runs multiple exported benchmarks and aggregates results.

    The composite runner forwards common overrides to each child benchmark and aggregates
    a combined score as a rows-weighted mean of each child's aggregated score.

    Registering under an existing ``name`` overwrites the previous registry entry.

    Args:
        name: Name of the composite benchmark to register.
        children: List of child benchmark names previously registered via export_benchmark.
    """

    def _composite_runner(
        *,
        model: Optional[str] = None,
        print_summary: bool = False,
        out: Optional[str] = None,
        reasoning_effort: Optional[str] = None,
        max_rows: Optional[int | str] = None,
        num_runs: Optional[int] = None,
        input_params_override: Optional[Dict[str, Any]] = None,
        max_concurrency: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Run every child benchmark and return ``{"summary": combined}``.

        All keyword arguments are forwarded unchanged to each child runner, except
        ``out``: when ``out`` names a ``.json`` file the children receive its parent
        directory, otherwise they receive ``out`` itself.

        The combined summary contains:
            suite / model / num_runs: echoed from the call.
            agg_score: rows-weighted mean of the children's ``agg_score`` (None when
                no child contributed).
            rows: total rows of the children that contributed to ``agg_score``, i.e.
                those with a numeric ``agg_score`` and a positive ``rows`` count.
            children: the child summaries, in execution order.
            metrics_agg: per-metric rows-weighted means (only when any were found).
            agg_ci_low / agg_ci_high: CI bounds recomputed over the children's
                combined ``results`` rows (only when both bounds are available).

        When ``out`` is given, the combined summary plus a ``timestamp`` is written
        as JSON, either to ``out`` itself (a ``.json`` file path) or to
        ``<out>/<name>__composite.json`` (``out`` treated as a directory).
        """
        import pathlib as _pathlib
        import time as _time

        # Decide once whether 'out' is a JSON file path or a directory; this drives
        # both the directory handed to the children and where the summary is written.
        out_path = _pathlib.Path(out) if out else None
        out_is_json_file = (
            out_path is not None
            and out_path.suffix.lower() == ".json"
            and not str(out).endswith("/")
        )
        child_out_dir: Optional[str] = None
        if out_path is not None:
            child_out_dir = str(out_path.parent) if out_is_json_file else out

        child_summaries: List[Dict[str, Any]] = []
        combined_rows: List[Any] = []
        total_rows = 0
        weighted_sum = 0.0
        # Per-metric accumulators; each metric keeps its own denominator because
        # not every child necessarily reports every metric.
        metric_weighted_sums: Dict[str, float] = {}
        metric_total_rows: Dict[str, int] = {}

        for child_name in children:
            # Resolved at call time (not registration time) so that children
            # registered after the composite are still found.
            runner = get_benchmark_runner(child_name)
            result = runner(
                model=model,
                print_summary=print_summary,
                out=child_out_dir,
                reasoning_effort=reasoning_effort,
                max_rows=max_rows,
                num_runs=num_runs,
                input_params_override=input_params_override,
                max_concurrency=max_concurrency,
            )
            if not isinstance(result, dict):
                continue
            summary = result.get("summary")
            # Skip children that produced no usable (dict-shaped) summary.
            if not summary or not isinstance(summary, dict):
                continue

            # Keep the underlying rows so the CI can be recomputed across children.
            rows_obj = result.get("results")
            if isinstance(rows_obj, list):
                combined_rows.extend(rows_obj)

            child_summaries.append(summary)
            rows = int(summary.get("rows", 0) or 0)
            if rows <= 0:
                # A child without rows carries no weight in any aggregate.
                continue

            agg = summary.get("agg_score")
            if isinstance(agg, (int, float)):
                total_rows += rows
                weighted_sum += float(agg) * rows

            metrics_agg = summary.get("metrics_agg") or {}
            if not isinstance(metrics_agg, dict):
                continue
            for m_name, m_vals in metrics_agg.items():
                # Tolerate malformed metric entries instead of aborting the whole run.
                if not isinstance(m_vals, dict):
                    continue
                m_mean = m_vals.get("mean")
                if isinstance(m_mean, (int, float)):
                    metric_weighted_sums[m_name] = metric_weighted_sums.get(m_name, 0.0) + float(m_mean) * rows
                    metric_total_rows[m_name] = metric_total_rows.get(m_name, 0) + rows

        combined_agg = (weighted_sum / total_rows) if total_rows > 0 else None

        # Confidence interval over the combined rows. This is best-effort: any
        # failure (stats module unavailable, unexpected row shape, ...) simply
        # omits the CI from the summary.
        ci_low: Optional[float] = None
        ci_high: Optional[float] = None
        if combined_rows:
            try:
                from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci as _compute_ci

                r = _compute_ci(combined_rows)
                if r and len(r) >= 3 and r[1] is not None and r[2] is not None:
                    ci_low = float(r[1])
                    ci_high = float(r[2])
            except Exception:
                ci_low = None
                ci_high = None
        has_ci = ci_low is not None and ci_high is not None

        # Every metric recorded above has a positive row total, so division is safe.
        combined_metrics: Dict[str, Dict[str, float]] = {
            m_name: {"mean": float(wsum / metric_total_rows[m_name])}
            for m_name, wsum in metric_weighted_sums.items()
        }

        combined: Dict[str, Any] = {
            "suite": name,
            "model": model,
            "agg_score": float(combined_agg) if combined_agg is not None else None,
            "rows": total_rows,
            "children": child_summaries,
            "num_runs": num_runs,
        }
        if combined_metrics:
            combined["metrics_agg"] = combined_metrics
        if has_ci:
            combined["agg_ci_low"] = ci_low
            combined["agg_ci_high"] = ci_high

        # Print when requested by argument or by the EP_PRINT_SUMMARY env var.
        if print_summary or (os.getenv("EP_PRINT_SUMMARY") == "1"):
            if combined_agg is None:
                score_part = "agg=None"
            elif has_ci:
                score_part = f"agg={combined['agg_score']:.3f} ci95=[{ci_low:.3f},{ci_high:.3f}]"
            else:
                score_part = f"agg={combined['agg_score']:.3f}"
            try:
                print(f"EP Summary | suite={name} model={model} {score_part} rows={total_rows}")
            except (OSError, ValueError):
                # Printing is informational only; a closed or mis-encoded stdout
                # must not fail the benchmark run.
                pass

        if out_path is not None:
            if out_is_json_file:
                target = out_path
            else:
                # Treat 'out' as a directory and derive a file name from the suite name.
                safe_name = name.replace("/", "__")
                target = out_path / f"{safe_name}__composite.json"
            target.parent.mkdir(parents=True, exist_ok=True)
            with open(target, "w", encoding="utf-8") as f:
                json.dump({**combined, "timestamp": int(_time.time())}, f)

        return {"summary": combined}

    # Register (overwrite if exists)
    _BENCHMARK_REGISTRY[name] = _composite_runner
332+
0 commit comments