Skip to content

Commit d7ff944

Browse files
committed
Developed a task to calculate fake factors for WJ and QCD
1 parent b0dc5f6 commit d7ff944

1 file changed

Lines changed: 119 additions & 53 deletions

File tree

columnflow/tasks/data_driven_methods.py

Lines changed: 119 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from columnflow.util import dev_sandbox, DotDict
2424

2525

26-
class CreateFakeFactorHistograms(
26+
class PrepareFakeFactorHistograms(
2727
VariablesMixin,
2828
WeightProducerMixin,
2929
ProducersMixin,
@@ -177,8 +177,7 @@ def run(self):
177177

178178
h = (hist.Hist.new
179179
.IntCat([], name="category", growth=True)
180-
.IntCat([], name="process", growth=True)
181-
.IntCat([], name="shift", growth=True))
180+
.IntCat([], name="process", growth=True))
182181
for (var_name, var_axis) in self.config_inst.x.fake_factor_method.axes.items():
183182
h = eval(f'h.{var_axis.ax_str}')
184183

@@ -189,11 +188,11 @@ def run(self):
189188
axis=-1,
190189
)
191190
# broadcast arrays so that each event can be filled for all its categories
191+
192192
fill_data = {
193193
"category" : category_ids,
194194
"process" : events.process_id,
195-
"shift" : np.ones(len(events), dtype=np.int32) * self.global_shift_inst.id,
196-
"weight": weight,
195+
"weight" : weight,
197196
}
198197
for (var_name, var_axis) in self.config_inst.x.fake_factor_method.axes.items():
199198
route = Route(var_axis.var_route)
@@ -214,19 +213,19 @@ def run(self):
214213

215214
# overwrite class defaults
216215
check_overlap_tasks = law.config.get_expanded("analysis", "check_overlapping_inputs", [], split_csv=True)
217-
CreateFakeFactorHistograms.check_overlapping_inputs = ChunkedIOMixin.check_overlapping_inputs.copy(
218-
default=CreateFakeFactorHistograms.task_family in check_overlap_tasks,
216+
PrepareFakeFactorHistograms.check_overlapping_inputs = ChunkedIOMixin.check_overlapping_inputs.copy(
217+
default=PrepareFakeFactorHistograms.task_family in check_overlap_tasks,
219218
add_default_to_description=True,
220219
)
221220

222221

223-
CreateFakeFactorHistogramsWrapper = wrapper_factory(
222+
PrepareFakeFactorHistogramsWrapper = wrapper_factory(
224223
base_cls=AnalysisTask,
225-
require_cls=CreateFakeFactorHistograms,
224+
require_cls=PrepareFakeFactorHistograms,
226225
enable=["configs", "skip_configs", "datasets", "skip_datasets", "shifts", "skip_shifts"],
227226
)
228227

229-
class MergeFakeFactors(
228+
class ComputeFakeFactors(
230229
VariablesMixin,
231230
DatasetsProcessesMixin,
232231
CategoriesMixin,
@@ -253,12 +252,12 @@ class MergeFakeFactors(
253252
# upstream requirements
254253
reqs = Requirements(
255254
RemoteWorkflow.reqs,
256-
CreateFakeFactorHistograms=CreateFakeFactorHistograms,
255+
PrepareFakeFactorHistograms=PrepareFakeFactorHistograms,
257256
)
258257

259258
def store_parts(self):
260259
parts = super().store_parts()
261-
parts.insert_before("version", "datasets" )#, f"datasets_{self.datasets_repr}")
260+
parts.insert_before("version", "datasets", f"datasets_{self.datasets_repr}")
262261
return parts
263262

264263
@classmethod
@@ -291,7 +290,7 @@ def workflow_requires(self):
291290
if not self.pilot:
292291
variables = self._get_variables()
293292
if variables:
294-
reqs["ff_method"] = self.reqs.CreateFakeFactorHistograms.req_different_branching(
293+
reqs["ff_method"] = self.reqs.PrepareFakeFactorHistograms.req_different_branching(
295294
self,
296295
branch=-1,
297296
variables=tuple(variables),
@@ -301,74 +300,141 @@ def workflow_requires(self):
301300

302301
def requires(self):
303302
return {
304-
d: self.reqs.CreateFakeFactorHistograms.req(
303+
d: self.reqs.PrepareFakeFactorHistograms.req(
305304
self,
306305
dataset=d,
307306
branch=-1,
308307
)
309308
for d in self.datasets
310309
}
311310
def output(self):
312-
return {"hists": self.target(f"fake_factors.pickle")}
311+
return {"ff_json": {ff_type: self.target(f"fake_factors_{ff_type}.json")for ff_type in ['qcd','wj']},
312+
"plots": {syst: self.target(f"fake_factor_syst_{syst}.png") for syst in ['nominal', 'up', 'down']},}
313313

314314
@law.decorator.log
315315
def run(self):
316316
import hist
317317
import numpy as np
318318
import matplotlib.pyplot as plt
319+
import correctionlib.convert as cl_convert
319320
# preare inputs and outputs
320321
inputs = self.input()
321322
outputs = self.output()
322323
merged_per_dataset = {}
323324
projected_hists = []
325+
hists_by_dataset = []
324326
for (dataset_name, dataset) in inputs.items():
325327
files = dataset['collection']
326328
# load input histograms per dataset
327-
hists = [
329+
hists_per_ds = [
328330
inp['hists'].load(formatter="pickle")['fake_factors']
329331
for inp in self.iter_progress(files.targets.values(), len(files), reach=(0, 50))
330332
]
331333
self.publish_message(f"merging Fake factor histograms for {dataset_name}")
332-
the_hist = sum(hists[1:], hists[0].copy())
333-
merged_per_dataset[dataset_name] = the_hist
334-
#Get axes names excluding 'process'. This is needed to merge hists for different processes
335-
ax_names = [ax_name for ax_name in the_hist.axes.name if ax_name != 'process']
336-
#Remove 'process' axis by projecting hist on the remaining axes
337-
projected_hists.append(the_hist.project(*ax_names))
338-
merged_hist = sum(projected_hists[1:], projected_hists[0].copy())
334+
ds_single_hist = sum(hists_per_ds[1:], hists_per_ds[0].copy())
335+
hists_by_dataset.append(ds_single_hist)
336+
337+
hists_by_proc = {}
338+
for proc_name in self.config_inst.processes.names():
339+
proc = self.config_inst.processes.get(proc_name)
340+
self.publish_message(f"merging Fake factor histograms for process: {proc.name}")
341+
for the_hist in hists_by_dataset:
342+
343+
if proc.id in the_hist.axes["process"]:
344+
h = the_hist.copy()
345+
h = h[{"process": hist.loc(proc.id)}]
346+
# add the histogram
347+
if proc in hists_by_proc:
348+
hists_by_proc[proc] += h
349+
else:
350+
hists_by_proc[proc] = h
339351

340-
cat_SR = self.config_inst.get_category(self.branch_data.category)
341-
cat_DR_den = self.config_inst.get_category(cat_SR.x.DR_den)
342-
cat_DR_num = self.config_inst.get_category(cat_SR.x.DR_num)
352+
mc_hists = [h for p, h in hists_by_proc.items() if p.is_mc and not p.has_tag("signal")]
353+
data_hists = [h for p, h in hists_by_proc.items() if p.is_data]
343354

344-
def get_hist (h, category):
345-
return h[{"category": hist.loc(category.id)}]
355+
mc_hists = sum(mc_hists[1:], mc_hists[0].copy())
356+
data_hists = sum(data_hists[1:], data_hists[0].copy())
346357

347-
h_DR_num = get_hist(merged_hist,cat_DR_num).values()
348-
h_DR_den = get_hist(merged_hist,cat_DR_den).values()
358+
dr_names = ['dr_num_wj','dr_den_wj','dr_num_qcd','dr_den_qcd']
359+
360+
def get_hist(h, category):
361+
return h[{"category": hist.loc(category.id)}]
349362

350-
ff_values = np.where((h_DR_num > 0) & (h_DR_den > 0),
351-
h_DR_num / np.maximum(h_DR_den, 1),
352-
0.0,
353-
)
354363

355-
#For the control: make 2d hists and plot them:
356-
hist2d = merged_hist.project('tau_pt','tau_dm_pnet')
357-
ff_hist = hist.Hist(*hist2d.axes, data=ff_values[0])
358-
fig, ax = plt.subplots(figsize=(12, 8))
359-
ff_hist.plot2d(ax=ax)
360-
plt.savefig('fake_factors.pdf')
361-
from IPython import embed; embed()
362-
#outputs["hists"][variable_name].dump(merged, formatter="pickle")F
363-
364-
# optionally remove inputs
365-
if self.remove_previous:
366-
inputs.remove()
367-
368-
369-
# MergeFakeFactorsWrapper = wrapper_factory(
370-
# base_cls=AnalysisTask,
371-
# require_cls=MergeFakeFactors,
372-
# enable=["configs", "skip_configs", "datasets", "skip_datasets", "shifts", "skip_shifts"],
373-
# )
364+
#Create two dictionaries that contain histograms for different determination regions
365+
data_h_cat ={}
366+
mc_h_cat = {}
367+
for dr_name in dr_names:
368+
cat = self.config_inst.get_category(self.branch_data.category.replace('sr',dr_name))
369+
data_h_cat[dr_name] = get_hist(data_hists, cat)
370+
mc_h_cat[dr_name] = get_hist(mc_hists, cat)
371+
372+
373+
def get_ff_corr(self, h_data, h_mc, num_cat, den_cat, name='ff_hist', label='ff_hist'):
374+
num = h_data[num_cat].values() - h_mc[num_cat].values()
375+
den = h_data[den_cat].values() - h_mc[den_cat].values()
376+
ff_val = np.where((num > 0) & (den > 0),
377+
num / np.maximum(den, 1),
378+
1)
379+
def rel_err(x):
380+
return x.variances()/np.maximum(x.values()**2, 1)
381+
ff_err2 = np.where((num > 0) & (den > 0),
382+
np.sqrt(rel_err(h_data[num_cat]) +
383+
+ rel_err(h_mc[den_cat]) +
384+
+ rel_err(h_data[num_cat]) +
385+
+ rel_err(h_mc[den_cat])) * ff_val**2,
386+
0.5* np.ones_like(ff_val))
387+
h = hist.Hist.new
388+
for (var_name, var_axis) in self.config_inst.x.fake_factor_method.axes.items():
389+
h = eval(f'h.{var_axis.ax_str}')
390+
h = h.StrCategory(['nominal', 'up', 'down'], name='syst', label='Statistical uncertainty of the fake factor')
391+
ff_hist= h.Weight()
392+
ff_hist.view().value[...,0] = ff_val
393+
ff_hist.view().value[...,1] = ff_val + np.sqrt(ff_err2)
394+
ff_hist.view().value[...,2] = np.maximum(ff_val - np.sqrt(ff_err2),0)
395+
ff_hist.name = name
396+
ff_hist.label = label
397+
ff_corr = cl_convert.from_histogram(ff_hist) #temporary correction without systematic axis
398+
ff_corr.data.flow = "clamp"
399+
return ff_corr, ff_hist
400+
401+
import rich
402+
403+
wj_corr, wj_h = get_ff_corr(self,
404+
data_h_cat,
405+
mc_h_cat,
406+
num_cat = 'dr_num_wj',
407+
den_cat = 'dr_den_wj',
408+
name='ff_wjets',
409+
label='Fake factor W+jets')
410+
411+
qcd_corr, qcd_h = get_ff_corr(self,
412+
data_h_cat,
413+
mc_h_cat,
414+
num_cat = 'dr_num_qcd',
415+
den_cat = 'dr_den_qcd',
416+
name='ff_qcd',
417+
label='Fake factor QCD')
418+
419+
for h_name in ['wj', 'qcd']:
420+
the_hist = eval(f'{h_name}_h')
421+
422+
for syst in ['nominal','up','down']:
423+
fig, ax = plt.subplots(figsize=(12, 8))
424+
the_hist[...,syst].plot2d(ax=ax)
425+
self.output()['plots'][syst].dump(fig, formatter="mpl")
426+
427+
428+
self.output()['ff_json']['wj'].dump(wj_corr.json(exclude_unset=True), formatter="json")
429+
self.output()['ff_json']['qcd'].dump(qcd_corr.json(exclude_unset=True), formatter="json")
430+
431+
432+
433+
434+
435+
436+
437+
438+
439+
374440

0 commit comments

Comments
 (0)