Skip to content

Commit

Permalink
add tests for aggfunc
Browse files Browse the repository at this point in the history
  • Loading branch information
samuel.oranyeli committed Sep 17, 2024
1 parent bfea362 commit af7ef10
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 24 deletions.
48 changes: 25 additions & 23 deletions janitor/functions/_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,29 +564,6 @@ def _numba_equi_join_range_join(
return l_index, r_index


@njit
def _numba_less_than(arr: np.ndarray, value: Any):
    """
    Binary search (bisect_left): return the earliest position in `arr`
    where arr[i] >= `value`.

    Assumes `arr` is sorted in ascending order.
    Returns -1 if `value` is greater than every element in `arr`.
    """
    min_idx = 0
    max_idx = len(arr)
    while min_idx < max_idx:
        # midpoint written as min + (max - min) // 2 to avoid overflow
        mid_idx = min_idx + ((max_idx - min_idx) >> 1)
        # index cast to uint64 before subscripting
        _mid_idx = np.uint64(mid_idx)
        if arr[_mid_idx] < value:
            min_idx = mid_idx + 1
        else:
            max_idx = mid_idx
    # min_idx == len(arr) means `value` is greater than
    # the max value in the array -> no valid position
    if min_idx == len(arr):
        return -1
    return min_idx


def _numba_single_non_equi_join(
left: pd.Series, right: pd.Series, op: str, keep: str
) -> tuple[np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -1175,6 +1152,31 @@ def _numba_multiple_non_equi_join(
return left_indices, right_indices


@njit
def _numba_less_than(arr: np.ndarray, value: Any):
    """
    Binary search (bisect_left): return the earliest position in `arr`
    where arr[i] >= `value`.

    Assumes `arr` is sorted in ascending order.
    Returns -1 if `value` is greater than every element in `arr`.
    Adapted from numba's internals.
    """
    min_idx = 0
    max_idx = len(arr)
    while min_idx < max_idx:
        # midpoint written as min + (max - min) // 2 to avoid overflow
        mid_idx = min_idx + ((max_idx - min_idx) >> 1)
        # index cast to uint64 before subscripting
        _mid_idx = np.uint64(mid_idx)
        if arr[_mid_idx] < value:
            min_idx = mid_idx + 1
        else:
            max_idx = mid_idx
    # min_idx == len(arr) means `value` is greater than
    # the max value in the array -> no valid position
    if min_idx == len(arr):
        return -1
    return min_idx


@njit(cache=True)
def _numba_non_equi_join_not_monotonic_keep_all(
left_regions: np.ndarray,
Expand Down
40 changes: 39 additions & 1 deletion janitor/functions/conditional_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def conditional_join(
use_numba: bool = False,
indicator: Optional[Union[bool, str]] = False,
force: bool = False,
aggfunc: dict = None,
) -> pd.DataFrame:
"""The conditional_join function operates similarly to `pd.merge`,
but supports joins on inequality operators,
Expand Down Expand Up @@ -265,7 +266,13 @@ def conditional_join(
`right_only` for observations whose merge key
only appears in the right DataFrame, and `both` if the observation’s
merge key is found in both DataFrames.
force: If `True`, force the non-equi join conditions to execute before the equi join.
force: If `True`, force the non-equi join conditions to execute
before the equi join.
aggfunc: Compute aggregation of values in the right dataframe,
for each index in the left dataframe.
If passed, the key is the column name to aggregate on,
while the value is the aggregation function.
Aggregations currently supported are sum, subtract, prod, size, count, min, max.
Returns:
Expand All @@ -283,6 +290,7 @@ def conditional_join(
use_numba=use_numba,
indicator=indicator,
force=force,
aggfunc=aggfunc,
)


Expand Down Expand Up @@ -314,6 +322,7 @@ def _conditional_join_preliminary_checks(
force: bool,
return_matching_indices: bool = False,
return_ragged_arrays: bool = False,
aggfunc: dict = None,
) -> tuple:
"""
Preliminary checks for conditional_join are conducted here.
Expand Down Expand Up @@ -399,6 +408,31 @@ def _conditional_join_preliminary_checks(

check("return_ragged_arrays", return_ragged_arrays, [bool])

if aggfunc:
if not use_numba:
raise ValueError("aggfunc works only when use_numba=True")
check("aggfunc", aggfunc, [dict])
dtypes = right.dtypes
funcs = {"sum", "subtract", "prod", "size", "count", "min", "max"}
for key, value in aggfunc.items():
_dtype = dtypes[key]
if (
not is_numeric_dtype(_dtype)
and not is_datetime64_dtype(_dtype)
and not is_timedelta64_dtype(_dtype)
):
raise TypeError(
"The aggregation column should be a "
"numeric or datetime or timedelta dtype; "
f"instead got {_dtype}"
)
if value not in funcs:
raise ValueError(
"The aggregation function should be one of "
f"{', '.join(funcs),}; "
f"instead got {value}"
)

return (
df,
right,
Expand All @@ -411,6 +445,7 @@ def _conditional_join_preliminary_checks(
indicator,
force,
return_ragged_arrays,
aggfunc,
)


Expand Down Expand Up @@ -463,6 +498,7 @@ def _conditional_join_compute(
force: bool,
return_matching_indices: bool = False,
return_ragged_arrays: bool = False,
aggfunc: bool = None,
) -> pd.DataFrame:
"""
This is where the actual computation
Expand All @@ -481,6 +517,7 @@ def _conditional_join_compute(
indicator,
force,
return_ragged_arrays,
aggfunc,
) = _conditional_join_preliminary_checks(
df=df,
right=right,
Expand All @@ -494,6 +531,7 @@ def _conditional_join_compute(
force=force,
return_matching_indices=return_matching_indices,
return_ragged_arrays=return_ragged_arrays,
aggfunc=aggfunc,
)
eq_check = False
le_lt_check = False
Expand Down
54 changes: 54 additions & 0 deletions tests/functions/test_conditional_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,60 @@ def test_check_how_value(dummy, series):
dummy.conditional_join(series, ("id", "B", "<"), how="INNER")


def test_check_aggfunc_value(dummy, series):
    """An unsupported aggregation name in `aggfunc` raises a ValueError.

    Supported values are sum, subtract, prod, size, count, min and max.
    """
    expected = "The aggregation function should be one of.+"
    with pytest.raises(ValueError, match=expected):
        dummy.conditional_join(
            series,
            ("id", "B", "<"),
            use_numba=True,
            aggfunc={"B": "rar"},
        )


def test_check_aggfunc_use_numba(dummy, series):
    """Passing `aggfunc` while `use_numba` is False raises a ValueError."""
    expected = "aggfunc works only when use_numba=True"
    with pytest.raises(ValueError, match=expected):
        dummy.conditional_join(
            series,
            ("id", "B", "<"),
            aggfunc={"B": "sum"},
        )


def test_check_aggfunc_type(dummy, series):
    """A non-dict `aggfunc` raises a TypeError."""
    with pytest.raises(TypeError, match="aggfunc should be one of.+"):
        dummy.conditional_join(
            series,
            ("id", "B", "<"),
            use_numba=True,
            aggfunc="size",
        )


def test_check_aggfunc_key(dummy, series):
    """An `aggfunc` key absent from the right frame raises a KeyError."""
    missing_column = {"C": "sum"}
    with pytest.raises(KeyError):
        dummy.conditional_join(
            series, ("id", "B", "<"), use_numba=True, aggfunc=missing_column
        )


def test_check_aggfunc_dtype(dummy, series):
    """Aggregating on a non numeric/datetime/timedelta column raises."""
    # build a right frame with a string column to aggregate on
    right = pd.DataFrame(series).assign(C="rar")
    with pytest.raises(TypeError, match="The aggregation column should be.+"):
        dummy.conditional_join(
            right, ("id", "B", "<"), use_numba=True, aggfunc={"C": "sum"}
        )


def test_df_columns(dummy):
"""
Raise TypeError if `df_columns`is a dictionary,
Expand Down

0 comments on commit af7ef10

Please sign in to comment.