Skip to content

Commit

Permalink
Merge pull request #1259 from scikit-hep/lumi_tools_numba_pickle
Browse files Browse the repository at this point in the history
fix: avoid pickling of numba objects
  • Loading branch information
lgray authored Jan 29, 2025
2 parents 6159ef5 + ecd1741 commit 7c551ea
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 24 deletions.
93 changes: 70 additions & 23 deletions src/coffea/lumi_tools/lumi_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,34 @@
from numba.typed import Dict


def wrap_get_lumi(runlumis, lumi_index):
runlumis_or_lz = awkward.typetracer.length_zero_if_typetracer(runlumis).to_numpy()
def _make_lumi_mask_dict():
    """Create an empty numba typed ``Dict`` for lumi-mask lookups.

    Returns
    -------
        A new, empty ``numba.typed.Dict`` whose keys are ``uint32`` and whose
        values are one-dimensional ``uint32`` arrays.
    """
    mask_dict = Dict.empty(
        key_type=types.uint32,
        value_type=types.uint32[:],
    )
    return mask_dict


def _make_lumi_data_dict():
    """Create an empty numba typed ``Dict`` for the luminosity lookup table.

    Returns
    -------
        A new, empty ``numba.typed.Dict`` keyed by a ``(uint32, uint32)`` tuple
        (filled elsewhere in this module as ``(run, lumi)``) with ``float64``
        values.
    """
    run_lumi_key = types.Tuple([types.uint32, types.uint32])
    lumi_table = Dict.empty(key_type=run_lumi_key, value_type=types.float64)
    return lumi_table


# Numba type of the typed dict produced by _make_lumi_mask_dict(); referenced in
# the explicit signature of LumiMask._apply_run_lumi_mask_kernel.
_lumi_mask_dict_type = numba.typeof(_make_lumi_mask_dict())

# Numba type of the typed dict produced by _make_lumi_data_dict(); referenced in
# the explicit signatures of the LumiData kernels.
_lumi_data_dict_type = numba.typeof(_make_lumi_data_dict())


def wrap_get_lumi(runlumis, lumi_index_astuple):
runlumis_or_lz = (
awkward.typetracer.length_zero_if_typetracer(runlumis)
.to_numpy()
.astype(numpy.uint32)
)
wrap_tot_lumi = numpy.zeros((1,))
lumi_index = _make_lumi_data_dict()
if isinstance(lumi_index_astuple, tuple):
LumiData._build_lumi_table_kernel(*lumi_index_astuple, lumi_index)
else:
lumi_index = lumi_index_astuple
LumiData._get_lumi_kernel(
runlumis_or_lz[:, 0], runlumis_or_lz[:, 1], lumi_index, wrap_tot_lumi
)
Expand Down Expand Up @@ -84,20 +109,24 @@ def get_lumi(self, runlumis):
-------
(float) The total integrated luminosity of the runs and lumisections indicated in `runlumis`.
"""

if isinstance(runlumis, LumiList):
runlumis = runlumis.array

if self.index is None:
self.index = Dict.empty(
key_type=types.Tuple([types.uint32, types.uint32]),
value_type=types.float64,
)
self.index = _make_lumi_data_dict()
runs = self._lumidata[:, 0].astype("u4")
lumis = self._lumidata[:, 1].astype("u4")
# fill self.index
LumiData._build_lumi_table_kernel(runs, lumis, self._lumidata, self.index)
LumiData._build_lumi_table_kernel(
runs, lumis, self._lumidata[:, 2], self.index
)
# delayed object cache
self.index_delayed = dask.delayed(self.index)
if isinstance(runlumis, dask_awkward.Array):
self.index_delayed = dask.delayed(
tuple([runs, lumis, self._lumidata[:, 2]])
)

if isinstance(runlumis, LumiList):
runlumis = runlumis.array
tot_lumi = numpy.zeros((1,), dtype=numpy.dtype("float64"))
if isinstance(runlumis, dask_awkward.Array):
lumi_meta = wrap_get_lumi(runlumis._meta, self.index)
Expand All @@ -120,20 +149,32 @@ def get_lumi(self, runlumis):
)

@staticmethod
@numba.njit(
    types.void(
        types.uint32[:], types.uint32[:], types.float64[:], _lumi_data_dict_type
    ),
    parallel=False,
    fastmath=False,
)
def _build_lumi_table_kernel(runs, lumis, lumidata, index):
    """Fill ``index`` with one luminosity value per ``(run, lumi)`` pair.

    Parameters
    ----------
        runs : uint32 array
            Run number for each row.
        lumis : uint32 array
            Lumisection number for each row.
        lumidata : float64 array
            One-dimensional array of luminosity values, aligned with ``runs``
            and ``lumis`` (callers pass a single column, not the full table).
        index : numba typed Dict
            Mapping ``(run, lumi) -> float64`` that is filled in place.

    A repeated ``(run, lumi)`` pair overwrites the earlier entry, so the last
    row wins.
    """
    # The explicit signature declares lumidata as 1-D, so it is indexed with a
    # single subscript; a 2-D ``lumidata[i, 2]`` access would not type-check.
    for i in range(len(runs)):
        run = runs[i]
        lumi = lumis[i]
        index[(run, lumi)] = lumidata[i]

@staticmethod
@numba.njit(parallel=False, fastmath=False)
@numba.njit(
types.void(
types.uint32[:], types.uint32[:], _lumi_data_dict_type, types.float64[:]
),
parallel=False,
fastmath=False,
)
def _get_lumi_kernel(runs, lumis, index, tot_lumi):
ks_done = set()
for iev in range(len(runs)):
run = numpy.uint32(runs[iev])
lumi = numpy.uint32(lumis[iev])
run = runs[iev]
lumi = lumis[iev]
k = (run, lumi)
if k not in ks_done:
ks_done.add(k)
Expand All @@ -156,7 +197,7 @@ def __init__(self, jsonfile):
with fsspec.open(jsonfile) as fin:
goldenjson = json.load(fin)

self._masks = {}
self._masks = dict()

for run, lumilist in goldenjson.items():
mask = numpy.array(lumilist, dtype=numpy.uint32).flatten()
Expand All @@ -183,20 +224,20 @@ def __call__(self, runs, lumis):

def apply(runs, lumis):
# fill numba typed dict
_masks = Dict.empty(key_type=types.uint32, value_type=types.uint32[:])
_masks = _make_lumi_mask_dict()
for k, v in self._masks.items():
_masks[k] = v

runs_orig = runs
if isinstance(runs, awkward.highlevel.Array):
runs = awkward.to_numpy(
awkward.typetracer.length_zero_if_typetracer(runs)
)
).astype(numpy.uint32)
if isinstance(lumis, awkward.highlevel.Array):
lumis = awkward.to_numpy(
awkward.typetracer.length_zero_if_typetracer(lumis)
)
mask_out = numpy.zeros(dtype="bool", shape=runs.shape)
).astype(numpy.uint32)
mask_out = numpy.zeros(dtype=bool, shape=runs.shape)
LumiMask._apply_run_lumi_mask_kernel(_masks, runs, lumis, mask_out)
if isinstance(runs_orig, awkward.Array):
mask_out = awkward.Array(mask_out)
Expand All @@ -213,7 +254,13 @@ def apply(runs, lumis):

# This could be run in parallel, but windows does not support it
@staticmethod
@numba.njit(parallel=False, fastmath=True)
@numba.njit(
types.void(
_lumi_mask_dict_type, types.uint32[:], types.uint32[:], types.bool[:]
),
parallel=False,
fastmath=True,
)
def _apply_run_lumi_mask_kernel(masks, runs, lumis, mask_out):
for iev in numba.prange(len(runs)):
run = numpy.uint32(runs[iev])
Expand Down Expand Up @@ -307,7 +354,7 @@ def __init__(self, runs=None, lumis=None, delayed=True):

self.array = None
if not delayed:
self.array = numpy.zeros(shape=(0, 2))
self.array = numpy.zeros(shape=(0, 2), dtype=numpy.uint32)

if isinstance(runs, dask_awkward.Array) and isinstance(
lumis, dask_awkward.Array
Expand Down Expand Up @@ -342,4 +389,4 @@ def clear(self):
"""Clear current lumi list"""
if isinstance(self.array, dask_awkward.Array):
raise RuntimeError("Delayed-mode LumiList cannot be cleared!")
self.array = numpy.zeros(shape=(0, 2))
self.array = numpy.zeros(shape=(0, 2), dtype=numpy.uint32)
4 changes: 4 additions & 0 deletions tests/samples/small_lumi.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#Data tag : 24v2 , Norm tag: None
#run:fill,ls,time,beamstatus,E(GeV),delivered(/pb),recorded(/pb),avgpu,source
370790:9073,54:54,07/16/23 22:51:10,STABLE BEAMS,6800,0.265372231,0.005674056,33.0,AVG
368229:8853,74:74,05/31/23 00:23:51,STABLE BEAMS,6800,0.106034314,0.101855712,27.4,AVG
25 changes: 24 additions & 1 deletion tests/test_lumi_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_lumidata():
pyruns = ld._lumidata[:, 0].astype("u4")
pylumis = ld._lumidata[:, 1].astype("u4")
LumiData._build_lumi_table_kernel.py_func(
pyruns, pylumis, ld._lumidata, py_index
pyruns, pylumis, ld._lumidata[:, 2], py_index
)

assert len(py_index) == len(ld.index)
Expand Down Expand Up @@ -175,3 +175,26 @@ def test_lumilist_client_fromfile():
(result,) = dask.compute(lumilist.array)

assert result.to_list() == [[1, 13889]]


def test_1259_avoid_pickle_numba_dict():
    """Check that ``LumiData.get_lumi`` agrees with and without a ``Client``.

    The same lazy computation is evaluated once with the default dask
    scheduler and once while a distributed ``Client`` is active; both results
    must be equal.
    """
    n_partitions = 2
    runs = dak.from_awkward(ak.Array([368229, 368229, 368229, 368229]), n_partitions)
    lumis = dak.from_awkward(ak.Array([74, 74, 74, 74]), n_partitions)

    def count_lumi(runs, lumis):
        # Build the lumi list and lumi data inside the helper so that every
        # call constructs a fresh task graph.
        lumilist = LumiList(runs, lumis)
        lumidata = LumiData("tests/samples/small_lumi.csv")
        accumulated = 0
        accumulated += lumidata.get_lumi(lumilist)
        return accumulated

    # Reference result: no distributed client is active.
    (noclient_output,) = dask.compute(count_lumi(runs, lumis))

    # Same computation while a distributed client is active.
    with Client():
        pending = count_lumi(runs, lumis)
        (client_output,) = dask.compute(pending)

    assert noclient_output == client_output

0 comments on commit 7c551ea

Please sign in to comment.