Skip to content

Commit

Permalink
Add numba tests as examples
Browse files Browse the repository at this point in the history
  • Loading branch information
martindurant committed Feb 3, 2025
1 parent 5378cb8 commit 7f75ffd
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tests/test_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,6 @@ def test_text_conversion():
s = pd.Series([["2024-08-01T01:00:00", None, "2024-08-01T01:01:00"]])
s2 = s.ak.str.strptime()
s3 = s2.ak.dt.strftime("%FT%T")
assert s3.tolist() == [["2024-08-01T01:00:00", None, "2024-08-01T01:01:00"]]
# remove trailing zeros - depends on system defaults
out = [None if _ is None else _.split(".")[0] for _ in s3.tolist()[0]]
assert out == ["2024-08-01T01:00:00", None, "2024-08-01T01:01:00"]
22 changes: 22 additions & 0 deletions tests/test_ray.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import awkward as ak
import numpy as np
import pytest

Expand Down Expand Up @@ -124,3 +125,24 @@ def test_overload(rayc):
def test_dir(df):
assert "flatten" in dir(df.ak)
assert "upper" in dir(df.ak.str)


def test_apply_numba(df):
numba = pytest.importorskip("numba")

@numba.njit()
def f(data: ak.Array, builder: ak.ArrayBuilder) -> None:
for i, item in enumerate(data.x):
if item[0] is None:
builder.append(None)
else:
builder.append(item[0][2] + item[2][0]) # always 6

def f2(data):
builder = ak.ArrayBuilder()
f(data, builder)
return builder.snapshot()

out = df.ak.apply(f2, where="x")
result = out.ak.to_output()
assert result.ak.tolist() == [6, None] * 100
25 changes: 25 additions & 0 deletions tests/test_spark.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import awkward as ak
import numpy as np
import pytest

Expand Down Expand Up @@ -134,3 +135,27 @@ def test_overload(spark):
def test_dir(df):
assert "flatten" in dir(df.ak)
assert "upper" in dir(df.ak.str)


def test_apply_numba(df):
numba = pytest.importorskip("numba")

def f(data: ak.Array, builder: ak.ArrayBuilder) -> None:
for i, item in enumerate(data.x):
if item[0] is None:
builder.append(None)
else:
builder.append(item[0][2] + item[2][0]) # always 6

def f2(data):
if len(data):
builder = ak.ArrayBuilder()
numba.njit(f)(data, builder)
return builder.snapshot()
else:
# default output for zero-length schema guesser
return ak.Array([None, 6])

out = df.ak.apply(f2, where="x")
result = out.ak.to_output()
assert result.ak.tolist() == [6, None] * 100

0 comments on commit 7f75ffd

Please sign in to comment.