Skip to content

Commit

Permalink
avoid pickling of numba objects
Browse files Browse the repository at this point in the history
avoid pickling of numba objects, use type hints to speed compilation
  • Loading branch information
lgray committed Jan 27, 2025
1 parent 1f6c06e commit d6f32fc
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 24 deletions.
93 changes: 70 additions & 23 deletions src/coffea/lumi_tools/lumi_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,34 @@
from numba.typed import Dict


def wrap_get_lumi(runlumis, lumi_index):
runlumis_or_lz = awkward.typetracer.length_zero_if_typetracer(runlumis).to_numpy()
def _make_lumi_mask_dict():
return Dict.empty(key_type=types.uint32, value_type=types.uint32[:])


def _make_lumi_data_dict():
return Dict.empty(
key_type=types.Tuple([types.uint32, types.uint32]),
value_type=types.float64,
)


_lumi_mask_dict_type = numba.typeof(_make_lumi_mask_dict())

_lumi_data_dict_type = numba.typeof(_make_lumi_data_dict())


def wrap_get_lumi(runlumis, lumi_index_astuple):
runlumis_or_lz = (
awkward.typetracer.length_zero_if_typetracer(runlumis)
.to_numpy()
.astype(numpy.uint32)
)
wrap_tot_lumi = numpy.zeros((1,))
lumi_index = _make_lumi_data_dict()
if isinstance(lumi_index_astuple, tuple):
LumiData._build_lumi_table_kernel(*lumi_index_astuple, lumi_index)
else:
lumi_index = lumi_index_astuple
LumiData._get_lumi_kernel(
runlumis_or_lz[:, 0], runlumis_or_lz[:, 1], lumi_index, wrap_tot_lumi
)
Expand Down Expand Up @@ -84,20 +109,24 @@ def get_lumi(self, runlumis):
-------
(float) The total integrated luminosity of the runs and lumisections indicated in `runlumis`.
"""

if isinstance(runlumis, LumiList):
runlumis = runlumis.array

if self.index is None:
self.index = Dict.empty(
key_type=types.Tuple([types.uint32, types.uint32]),
value_type=types.float64,
)
self.index = _make_lumi_data_dict()
runs = self._lumidata[:, 0].astype("u4")
lumis = self._lumidata[:, 1].astype("u4")
# fill self.index
LumiData._build_lumi_table_kernel(runs, lumis, self._lumidata, self.index)
LumiData._build_lumi_table_kernel(
runs, lumis, self._lumidata[:, 2], self.index
)
# delayed object cache
self.index_delayed = dask.delayed(self.index)
if isinstance(runlumis, dask_awkward.Array):
self.index_delayed = dask.delayed(
tuple([runs, lumis, self._lumidata[:, 2]])
)

if isinstance(runlumis, LumiList):
runlumis = runlumis.array
tot_lumi = numpy.zeros((1,), dtype=numpy.dtype("float64"))
if isinstance(runlumis, dask_awkward.Array):
lumi_meta = wrap_get_lumi(runlumis._meta, self.index)
Expand All @@ -120,20 +149,32 @@ def get_lumi(self, runlumis):
)

@staticmethod
@numba.njit(parallel=False, fastmath=False)
@numba.njit(
types.void(
types.uint32[:], types.uint32[:], types.float64[:], _lumi_data_dict_type
),
parallel=False,
fastmath=False,
)
def _build_lumi_table_kernel(runs, lumis, lumidata, index):
for i in range(len(runs)):
run = runs[i]
lumi = lumis[i]
index[(run, lumi)] = float(lumidata[i, 2])
index[(run, lumi)] = lumidata[i]

@staticmethod
@numba.njit(parallel=False, fastmath=False)
@numba.njit(
types.void(
types.uint32[:], types.uint32[:], _lumi_data_dict_type, types.float64[:]
),
parallel=False,
fastmath=False,
)
def _get_lumi_kernel(runs, lumis, index, tot_lumi):
ks_done = set()
for iev in range(len(runs)):
run = numpy.uint32(runs[iev])
lumi = numpy.uint32(lumis[iev])
run = runs[iev]
lumi = lumis[iev]
k = (run, lumi)
if k not in ks_done:
ks_done.add(k)
Expand All @@ -156,7 +197,7 @@ def __init__(self, jsonfile):
with fsspec.open(jsonfile) as fin:
goldenjson = json.load(fin)

self._masks = {}
self._masks = dict()

for run, lumilist in goldenjson.items():
mask = numpy.array(lumilist, dtype=numpy.uint32).flatten()
Expand All @@ -183,20 +224,20 @@ def __call__(self, runs, lumis):

def apply(runs, lumis):
# fill numba typed dict
_masks = Dict.empty(key_type=types.uint32, value_type=types.uint32[:])
_masks = _make_lumi_mask_dict()
for k, v in self._masks.items():
_masks[k] = v

runs_orig = runs
if isinstance(runs, awkward.highlevel.Array):
runs = awkward.to_numpy(
awkward.typetracer.length_zero_if_typetracer(runs)
)
).astype(numpy.uint32)
if isinstance(lumis, awkward.highlevel.Array):
lumis = awkward.to_numpy(
awkward.typetracer.length_zero_if_typetracer(lumis)
)
mask_out = numpy.zeros(dtype="bool", shape=runs.shape)
).astype(numpy.uint32)
mask_out = numpy.zeros(dtype=bool, shape=runs.shape)
LumiMask._apply_run_lumi_mask_kernel(_masks, runs, lumis, mask_out)
if isinstance(runs_orig, awkward.Array):
mask_out = awkward.Array(mask_out)
Expand All @@ -213,7 +254,13 @@ def apply(runs, lumis):

# This could be run in parallel, but windows does not support it
@staticmethod
@numba.njit(parallel=False, fastmath=True)
@numba.njit(
types.void(
_lumi_mask_dict_type, types.uint32[:], types.uint32[:], types.bool[:]
),
parallel=False,
fastmath=True,
)
def _apply_run_lumi_mask_kernel(masks, runs, lumis, mask_out):
for iev in numba.prange(len(runs)):
run = numpy.uint32(runs[iev])
Expand Down Expand Up @@ -307,7 +354,7 @@ def __init__(self, runs=None, lumis=None, delayed=True):

self.array = None
if not delayed:
self.array = numpy.zeros(shape=(0, 2))
self.array = numpy.zeros(shape=(0, 2), dtype=numpy.uint32)

if isinstance(runs, dask_awkward.Array) and isinstance(
lumis, dask_awkward.Array
Expand Down Expand Up @@ -342,4 +389,4 @@ def clear(self):
"""Clear current lumi list"""
if isinstance(self.array, dask_awkward.Array):
raise RuntimeError("Delayed-mode LumiList cannot be cleared!")
self.array = numpy.zeros(shape=(0, 2))
self.array = numpy.zeros(shape=(0, 2), dtype=numpy.uint32)
2 changes: 1 addition & 1 deletion tests/test_lumi_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_lumidata():
pyruns = ld._lumidata[:, 0].astype("u4")
pylumis = ld._lumidata[:, 1].astype("u4")
LumiData._build_lumi_table_kernel.py_func(
pyruns, pylumis, ld._lumidata, py_index
pyruns, pylumis, ld._lumidata[:, 2], py_index
)

assert len(py_index) == len(ld.index)
Expand Down

0 comments on commit d6f32fc

Please sign in to comment.