2323from columnflow .util import dev_sandbox , DotDict
2424
2525
26- class CreateFakeFactorHistograms (
26+ class PrepareFakeFactorHistograms (
2727 VariablesMixin ,
2828 WeightProducerMixin ,
2929 ProducersMixin ,
@@ -177,8 +177,7 @@ def run(self):
177177
178178 h = (hist .Hist .new
179179 .IntCat ([], name = "category" , growth = True )
180- .IntCat ([], name = "process" , growth = True )
181- .IntCat ([], name = "shift" , growth = True ))
180+ .IntCat ([], name = "process" , growth = True ))
182181 for (var_name , var_axis ) in self .config_inst .x .fake_factor_method .axes .items ():
183182 h = eval (f'h.{ var_axis .ax_str } ' )
184183
@@ -189,11 +188,11 @@ def run(self):
189188 axis = - 1 ,
190189 )
191190 # broadcast arrays so that each event can be filled for all its categories
191+
192192 fill_data = {
193193 "category" : category_ids ,
194194 "process" : events .process_id ,
195- "shift" : np .ones (len (events ), dtype = np .int32 ) * self .global_shift_inst .id ,
196- "weight" : weight ,
195+ "weight" : weight ,
197196 }
198197 for (var_name , var_axis ) in self .config_inst .x .fake_factor_method .axes .items ():
199198 route = Route (var_axis .var_route )
@@ -214,19 +213,19 @@ def run(self):
214213
# Override the class-level default of ``check_overlapping_inputs``: the check is enabled by
# default only when the task family of PrepareFakeFactorHistograms is listed in the
# "check_overlapping_inputs" option of the "analysis" section of the law config.
check_overlap_tasks = law.config.get_expanded(
    "analysis",
    "check_overlapping_inputs",
    [],
    split_csv=True,
)
PrepareFakeFactorHistograms.check_overlapping_inputs = ChunkedIOMixin.check_overlapping_inputs.copy(
    default=(PrepareFakeFactorHistograms.task_family in check_overlap_tasks),
    add_default_to_description=True,
)
221220
222221
# Wrapper task built by ``wrapper_factory`` that requires PrepareFakeFactorHistograms;
# the ``enable`` list names the parameters switched on for the wrapper.
PrepareFakeFactorHistogramsWrapper = wrapper_factory(
    base_cls=AnalysisTask,
    require_cls=PrepareFakeFactorHistograms,
    enable=[
        "configs",
        "skip_configs",
        "datasets",
        "skip_datasets",
        "shifts",
        "skip_shifts",
    ],
)
228227
229- class MergeFakeFactors (
228+ class ComputeFakeFactors (
230229 VariablesMixin ,
231230 DatasetsProcessesMixin ,
232231 CategoriesMixin ,
@@ -253,12 +252,12 @@ class MergeFakeFactors(
253252 # upstream requirements
254253 reqs = Requirements (
255254 RemoteWorkflow .reqs ,
256- CreateFakeFactorHistograms = CreateFakeFactorHistograms ,
255+ PrepareFakeFactorHistograms = PrepareFakeFactorHistograms ,
257256 )
258257
def store_parts(self):
    """
    Extend the store parts of the parent class with a datasets-specific entry.

    An entry with key ``"datasets"`` and value ``f"datasets_{self.datasets_repr}"`` is
    inserted before the ``"version"`` entry of the parts returned by the parent class.

    :return: The extended store parts object.
    """
    store = super().store_parts()
    datasets_part = f"datasets_{self.datasets_repr}"
    store.insert_before("version", "datasets", datasets_part)
    return store
263262
264263 @classmethod
@@ -291,7 +290,7 @@ def workflow_requires(self):
291290 if not self .pilot :
292291 variables = self ._get_variables ()
293292 if variables :
294- reqs ["ff_method" ] = self .reqs .CreateFakeFactorHistograms .req_different_branching (
293+ reqs ["ff_method" ] = self .reqs .PrepareFakeFactorHistograms .req_different_branching (
295294 self ,
296295 branch = - 1 ,
297296 variables = tuple (variables ),
@@ -301,74 +300,141 @@ def workflow_requires(self):
301300
def requires(self):
    """
    Build the requirements of this task, one per dataset.

    :return: A dict mapping each entry of ``self.datasets`` to the PrepareFakeFactorHistograms
        requirement created for that dataset with ``branch=-1``.
    """
    requirements = {}
    for dataset_name in self.datasets:
        requirements[dataset_name] = self.reqs.PrepareFakeFactorHistograms.req(
            self,
            dataset=dataset_name,
            branch=-1,
        )
    return requirements
def output(self):
    """
    Declare the output targets of this task.

    :return: A dict with two entries:
        ``"ff_json"`` maps the fake factor types ``"qcd"`` and ``"wj"`` to the targets
        ``fake_factors_<type>.json``, and ``"plots"`` maps the variations ``"nominal"``,
        ``"up"`` and ``"down"`` to the targets ``fake_factor_syst_<variation>.png``.
    """
    ff_json = {}
    for ff_type in ['qcd', 'wj']:
        ff_json[ff_type] = self.target(f"fake_factors_{ff_type}.json")

    plots = {}
    for syst in ['nominal', 'up', 'down']:
        plots[syst] = self.target(f"fake_factor_syst_{syst}.png")

    return {"ff_json": ff_json, "plots": plots}
313313
@law.decorator.log
def run(self):
    """
    Derive W+jets and QCD fake factors from the upstream histograms.

    Steps performed:

    1. Load the pickled ``"fake_factors"`` histograms of all input files and sum them
       per dataset.
    2. Split the per-dataset histograms by process id and sum them per process, then
       build one summed histogram for data and one for non-signal MC.
    3. Select the four determination region categories, whose names are obtained by
       replacing ``"sr"`` in the branch category name with ``dr_num_wj``, ``dr_den_wj``,
       ``dr_num_qcd`` and ``dr_den_qcd``.
    4. Compute the fake factor ``(data - MC)_num / (data - MC)_den`` per bin, with
       ``up`` / ``down`` variations given by its statistical uncertainty.
    5. Write one correction JSON per fake factor type and one control plot per
       variation (W+jets and QCD side by side).

    :raises ValueError: If no data or no non-signal MC histogram is available.
    """
    import hist
    import numpy as np
    import matplotlib.pyplot as plt
    import correctionlib.convert as cl_convert

    systs = ["nominal", "up", "down"]

    # prepare inputs and outputs
    inputs = self.input()
    outputs = self.output()

    def sum_hists(hists, kind):
        # sum into a copy of the first histogram so that no input is modified in place
        if not hists:
            raise ValueError(f"no {kind} histograms available to compute fake factors")
        return sum(hists[1:], hists[0].copy())

    # merge the histograms of all input files into a single histogram per dataset
    hists_by_dataset = []
    for dataset_name, dataset in inputs.items():
        files = dataset["collection"]
        hists_per_ds = [
            inp["hists"].load(formatter="pickle")["fake_factors"]
            for inp in self.iter_progress(files.targets.values(), len(files), reach=(0, 50))
        ]
        self.publish_message(f"merging Fake factor histograms for {dataset_name}")
        hists_by_dataset.append(sum_hists(hists_per_ds, f"'{dataset_name}'"))

    # regroup the content by process, dropping the "process" axis on the way
    hists_by_proc = {}
    for proc_name in self.config_inst.processes.names():
        proc = self.config_inst.processes.get(proc_name)
        self.publish_message(f"merging Fake factor histograms for process: {proc.name}")
        for ds_hist in hists_by_dataset:
            if proc.id not in ds_hist.axes["process"]:
                continue
            proc_hist = ds_hist[{"process": hist.loc(proc.id)}]
            if proc in hists_by_proc:
                hists_by_proc[proc] += proc_hist
            else:
                hists_by_proc[proc] = proc_hist

    # signal processes do not enter the MC subtraction
    mc_hist = sum_hists(
        [h for p, h in hists_by_proc.items() if p.is_mc and not p.has_tag("signal")],
        "MC",
    )
    data_hist = sum_hists(
        [h for p, h in hists_by_proc.items() if p.is_data],
        "data",
    )

    # histograms of data and MC in each determination region
    # NOTE(review): str.replace substitutes every occurrence of "sr" in the category
    # name; confirm that category names contain it only once
    data_h_cat = {}
    mc_h_cat = {}
    for dr_name in ["dr_num_wj", "dr_den_wj", "dr_num_qcd", "dr_den_qcd"]:
        cat = self.config_inst.get_category(self.branch_data.category.replace("sr", dr_name))
        selection = {"category": hist.loc(cat.id)}
        data_h_cat[dr_name] = data_hist[selection]
        mc_h_cat[dr_name] = mc_hist[selection]

    def get_ff_corr(h_data, h_mc, num_cat, den_cat, name="ff_hist", label="ff_hist"):
        """
        Compute the fake factor for one pair of determination regions.

        Returns a tuple ``(correction, histogram)``; the histogram has the configured
        fake factor axes plus a ``"syst"`` axis with entries nominal / up / down.
        """
        # MC-subtracted yields in numerator and denominator region
        num = h_data[num_cat].values() - h_mc[num_cat].values()
        den = h_data[den_cat].values() - h_mc[den_cat].values()
        valid = (num > 0) & (den > 0)

        # divide only where both yields are positive, use 1 everywhere else;
        # invalid bins get a dummy divisor so that no division warning is raised
        safe_num = np.where(valid, num, 1.0)
        safe_den = np.where(valid, den, 1.0)
        ff_val = np.where(valid, num / safe_den, 1.0)

        # error propagation for a ratio of two differences, treating data and MC as
        # independent: var(a - b) = var(a) + var(b) and
        # var(N / D) = (N / D)^2 * (var(N) / N^2 + var(D) / D^2)
        num_var = h_data[num_cat].variances() + h_mc[num_cat].variances()
        den_var = h_data[den_cat].variances() + h_mc[den_cat].variances()
        rel_var = num_var / safe_num**2 + den_var / safe_den**2
        # bins without a valid fake factor keep the fixed fallback variance of 0.5
        ff_err = np.sqrt(np.where(valid, ff_val**2 * rel_var, 0.5))

        # build the output histogram from the configured axis definitions
        # NOTE: eval() executes the axis strings stored in the analysis config; they
        # must come from a trusted source
        h = hist.Hist.new
        for var_axis in self.config_inst.x.fake_factor_method.axes.values():
            h = eval(f"h.{var_axis.ax_str}")
        h = h.StrCategory(
            ["nominal", "up", "down"],
            name="syst",
            label="Statistical uncertainty of the fake factor",
        )
        ff_hist = h.Weight()
        ff_values = ff_hist.view().value
        ff_values[..., 0] = ff_val
        ff_values[..., 1] = ff_val + ff_err
        # the down variation is not allowed to become negative
        ff_values[..., 2] = np.maximum(ff_val - ff_err, 0)
        ff_hist.name = name
        ff_hist.label = label

        ff_corr = cl_convert.from_histogram(ff_hist)  # temporary correction without systematic axis
        ff_corr.data.flow = "clamp"
        return ff_corr, ff_hist

    ff_results = {
        "wj": get_ff_corr(
            data_h_cat,
            mc_h_cat,
            num_cat="dr_num_wj",
            den_cat="dr_den_wj",
            name="ff_wjets",
            label="Fake factor W+jets",
        ),
        "qcd": get_ff_corr(
            data_h_cat,
            mc_h_cat,
            num_cat="dr_num_qcd",
            den_cat="dr_den_qcd",
            name="ff_qcd",
            label="Fake factor QCD",
        ),
    }

    # there is a single plot target per variation, so both fake factor types are drawn
    # side by side into the same figure instead of overwriting each other
    for syst in systs:
        fig, axes = plt.subplots(
            1,
            len(ff_results),
            figsize=(12 * len(ff_results), 8),
            squeeze=False,
        )
        for ax, (_, ff_hist) in zip(axes[0], ff_results.values()):
            ff_hist[..., syst].plot2d(ax=ax)
            ax.set_title(ff_hist.label)
        outputs["plots"][syst].dump(fig, formatter="mpl")
        # release the figure once it is written
        plt.close(fig)

    # NOTE(review): the return value of ``.json()`` is handed to the json formatter
    # unchanged; confirm that consumers of these files expect that encoding
    for ff_type, (ff_corr, _) in ff_results.items():
        outputs["ff_json"][ff_type].dump(ff_corr.json(exclude_unset=True), formatter="json")
430+
431+
432+
433+
434+
435+
436+
437+
438+
439+
374440
0 commit comments