diff --git a/src/coffea/lumi_tools/lumi_tools.py b/src/coffea/lumi_tools/lumi_tools.py index 29170d554..32542af8c 100644 --- a/src/coffea/lumi_tools/lumi_tools.py +++ b/src/coffea/lumi_tools/lumi_tools.py @@ -15,9 +15,34 @@ from numba.typed import Dict -def wrap_get_lumi(runlumis, lumi_index): - runlumis_or_lz = awkward.typetracer.length_zero_if_typetracer(runlumis).to_numpy() +def _make_lumi_mask_dict(): + return Dict.empty(key_type=types.uint32, value_type=types.uint32[:]) + + +def _make_lumi_data_dict(): + return Dict.empty( + key_type=types.Tuple([types.uint32, types.uint32]), + value_type=types.float64, + ) + + +_lumi_mask_dict_type = numba.typeof(_make_lumi_mask_dict()) + +_lumi_data_dict_type = numba.typeof(_make_lumi_data_dict()) + + +def wrap_get_lumi(runlumis, lumi_index_astuple): + runlumis_or_lz = ( + awkward.typetracer.length_zero_if_typetracer(runlumis) + .to_numpy() + .astype(numpy.uint32) + ) wrap_tot_lumi = numpy.zeros((1,)) + lumi_index = _make_lumi_data_dict() + if isinstance(lumi_index_astuple, tuple): + LumiData._build_lumi_table_kernel(*lumi_index_astuple, lumi_index) + else: + lumi_index = lumi_index_astuple LumiData._get_lumi_kernel( runlumis_or_lz[:, 0], runlumis_or_lz[:, 1], lumi_index, wrap_tot_lumi ) @@ -84,20 +109,24 @@ def get_lumi(self, runlumis): ------- (float) The total integrated luminosity of the runs and lumisections indicated in `runlumis`. """ + + if isinstance(runlumis, LumiList): + runlumis = runlumis.array + if self.index is None: - self.index = Dict.empty( - key_type=types.Tuple([types.uint32, types.uint32]), - value_type=types.float64, - ) + self.index = _make_lumi_data_dict() runs = self._lumidata[:, 0].astype("u4") lumis = self._lumidata[:, 1].astype("u4") # fill self.index - LumiData._build_lumi_table_kernel(runs, lumis, self._lumidata, self.index) + LumiData._build_lumi_table_kernel( + runs, lumis, self._lumidata[:, 2], self.index + ) # delayed object cache - self.index_delayed = dask.delayed(self.index) + if isinstance(runlumis, dask_awkward.Array): + self.index_delayed = dask.delayed( + tuple([runs, lumis, self._lumidata[:, 2]]) + ) - if isinstance(runlumis, LumiList): - runlumis = runlumis.array tot_lumi = numpy.zeros((1,), dtype=numpy.dtype("float64")) if isinstance(runlumis, dask_awkward.Array): lumi_meta = wrap_get_lumi(runlumis._meta, self.index) @@ -120,20 +149,32 @@ def get_lumi(self, runlumis): ) @staticmethod - @numba.njit(parallel=False, fastmath=False) + @numba.njit( + types.void( + types.uint32[:], types.uint32[:], types.float64[:], _lumi_data_dict_type + ), + parallel=False, + fastmath=False, + ) def _build_lumi_table_kernel(runs, lumis, lumidata, index): for i in range(len(runs)): run = runs[i] lumi = lumis[i] - index[(run, lumi)] = float(lumidata[i, 2]) + index[(run, lumi)] = lumidata[i] @staticmethod - @numba.njit(parallel=False, fastmath=False) + @numba.njit( + types.void( + types.uint32[:], types.uint32[:], _lumi_data_dict_type, types.float64[:] + ), + parallel=False, + fastmath=False, + ) def _get_lumi_kernel(runs, lumis, index, tot_lumi): ks_done = set() for iev in range(len(runs)): - run = numpy.uint32(runs[iev]) - lumi = numpy.uint32(lumis[iev]) + run = runs[iev] + lumi = lumis[iev] k = (run, lumi) if k not in ks_done: ks_done.add(k) @@ -156,7 +197,7 @@ def __init__(self, jsonfile): with fsspec.open(jsonfile) as fin: goldenjson = json.load(fin) - self._masks = {} + self._masks = dict() for run, lumilist in goldenjson.items(): mask = numpy.array(lumilist, dtype=numpy.uint32).flatten() @@ -183,7 +224,7 @@ def __call__(self, runs, lumis): def apply(runs, lumis): # fill numba typed dict - _masks = Dict.empty(key_type=types.uint32, value_type=types.uint32[:]) + _masks = _make_lumi_mask_dict() for k, v in self._masks.items(): _masks[k] = v @@ -191,12 +232,12 @@ def apply(runs, lumis): if isinstance(runs, awkward.highlevel.Array): runs = awkward.to_numpy( awkward.typetracer.length_zero_if_typetracer(runs) - ) + ).astype(numpy.uint32) if isinstance(lumis, awkward.highlevel.Array): lumis = awkward.to_numpy( awkward.typetracer.length_zero_if_typetracer(lumis) - ) - mask_out = numpy.zeros(dtype="bool", shape=runs.shape) + ).astype(numpy.uint32) + mask_out = numpy.zeros(dtype=bool, shape=runs.shape) LumiMask._apply_run_lumi_mask_kernel(_masks, runs, lumis, mask_out) if isinstance(runs_orig, awkward.Array): mask_out = awkward.Array(mask_out) @@ -213,7 +254,13 @@ def apply(runs, lumis): # This could be run in parallel, but windows does not support it @staticmethod - @numba.njit(parallel=False, fastmath=True) + @numba.njit( + types.void( + _lumi_mask_dict_type, types.uint32[:], types.uint32[:], types.bool[:] + ), + parallel=False, + fastmath=True, + ) def _apply_run_lumi_mask_kernel(masks, runs, lumis, mask_out): for iev in numba.prange(len(runs)): run = numpy.uint32(runs[iev]) @@ -307,7 +354,7 @@ def __init__(self, runs=None, lumis=None, delayed=True): self.array = None if not delayed: - self.array = numpy.zeros(shape=(0, 2)) + self.array = numpy.zeros(shape=(0, 2), dtype=numpy.uint32) if isinstance(runs, dask_awkward.Array) and isinstance( lumis, dask_awkward.Array @@ -342,4 +389,4 @@ def clear(self): """Clear current lumi list""" if isinstance(self.array, dask_awkward.Array): raise RuntimeError("Delayed-mode LumiList cannot be cleared!") - self.array = numpy.zeros(shape=(0, 2)) + self.array = numpy.zeros(shape=(0, 2), dtype=numpy.uint32) diff --git a/tests/test_lumi_tools.py b/tests/test_lumi_tools.py index 989c86acf..eed65d845 100644 --- a/tests/test_lumi_tools.py +++ b/tests/test_lumi_tools.py @@ -40,7 +40,7 @@ def test_lumidata(): pyruns = ld._lumidata[:, 0].astype("u4") pylumis = ld._lumidata[:, 1].astype("u4") LumiData._build_lumi_table_kernel.py_func( - pyruns, pylumis, ld._lumidata, py_index + pyruns, pylumis, ld._lumidata[:, 2], py_index ) assert len(py_index) == len(ld.index)