Merge pull request #141 from xhochy/no-regex-case-sensitive-contains

Add numba implementation for contains(.., regex=False, case=True)

xhochy authored Jun 16, 2020
2 parents d58e414 + 8d2bbbf commit 2a93fca
Showing 5 changed files with 269 additions and 13 deletions.
19 changes: 19 additions & 0 deletions benchmarks/benchmarks.py
@@ -15,6 +15,25 @@ def generate_test_array(n):
]


def generate_test_array_non_null(n):
return [six.text_type(x) + six.text_type(x) + six.text_type(x) for x in range(n)]


class TimeSuiteNonNull:
def setup(self):
array = generate_test_array_non_null(2 ** 17)
self.df = pd.DataFrame({"str": array})
self.df_ext = pd.DataFrame(
{"str": fr.FletcherChunkedArray(pa.array(array, pa.string()))}
)

def time_contains_no_regex(self):
self.df["str"].str.contains("0", regex=False)

def time_contains_no_regex_ext(self):
self.df_ext["str"].text.contains("0", regex=False)


class TimeSuite:
def setup(self):
array = generate_test_array(2 ** 17)
41 changes: 40 additions & 1 deletion fletcher/_algorithms.py
@@ -1,4 +1,4 @@
-from functools import partial, singledispatch
from functools import partial, singledispatch, wraps
from typing import Any, Callable, List, Optional, Tuple, Union

import numba
@@ -561,3 +561,42 @@ def _merge_valid_bitmaps(a: pa.Array, b: pa.Array) -> np.ndarray:
)

return result


def apply_per_chunk(func):
"""Apply a function to each chunk if the input is chunked."""

@wraps(func)
def wrapper(arr: Union[pa.Array, pa.ChunkedArray], *args, **kwargs):
if isinstance(arr, pa.ChunkedArray):
return pa.chunked_array(
[func(chunk, *args, **kwargs) for chunk in arr.chunks]
)
else:
return func(arr, *args, **kwargs)

return wrapper


@apply_per_chunk
def all_true_like(arr: pa.Array) -> pa.Array:
"""Return a boolean array with all-True with the same size as the input."""
valid_buffer = arr.buffers()[0]
if valid_buffer:
valid_buffer = valid_buffer.slice(arr.offset // 8)

output_offset = arr.offset % 8
output_length = len(arr) + output_offset

output_size = output_length // 8
if output_length % 8 > 0:
output_size += 1
output = np.full(output_size, fill_value=255, dtype=np.uint8)

return pa.Array.from_buffers(
pa.bool_(),
len(arr),
[valid_buffer, pa.py_buffer(output)],
arr.null_count,
output_offset,
)
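
A minimal usage sketch of the two helpers above (assuming the import path from this commit; pyarrow required):

import pyarrow as pa
from fletcher._algorithms import all_true_like

# One chunk contains a null; all_true_like keeps it null because the
# input's validity buffer is passed through to the result unchanged.
chunked = pa.chunked_array([pa.array(["a", None]), pa.array(["b"])])
result = all_true_like(chunked)  # @apply_per_chunk runs the kernel per chunk
assert result.to_pylist() == [True, None, True]
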
154 changes: 154 additions & 0 deletions fletcher/algorithms/string.py
@@ -6,9 +6,11 @@
from numba import njit

from fletcher._algorithms import (
_buffer_to_view,
_calculate_chunk_offsets,
_combined_in_chunk_offsets,
_merge_valid_bitmaps,
apply_per_chunk,
)


@@ -112,3 +114,155 @@ def _text_cat(a: pa.Array, b: pa.Array) -> pa.Array:
buffers = [pa.py_buffer(x) for x in [valid, result_offsets, result_data]]
return pa.Array.from_buffers(pa.string(), len(a), buffers)
return a


@njit
def _text_contains_case_sensitive_nonnull(
length: int, offsets: np.ndarray, data: np.ndarray, pat, output: np.ndarray
) -> None:
for row_idx in range(length):
str_len = offsets[row_idx + 1] - offsets[row_idx]

contains = False
for str_idx in range(max(0, str_len - len(pat) + 1)):
pat_found = True
for pat_idx in range(len(pat)):
if data[offsets[row_idx] + str_idx + pat_idx] != pat[pat_idx]:
pat_found = False
break
if pat_found:
contains = True
break

# TODO: Set word-wise for better performance
byte_offset_result = row_idx // 8
bit_offset_result = row_idx % 8
mask_result = np.uint8(1 << bit_offset_result)
current = output[byte_offset_result]
if contains: # must be logical, not bit-wise as different bits may be flagged
output[byte_offset_result] = current | mask_result
else:
output[byte_offset_result] = current & ~mask_result
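
The kernel writes its results into an Arrow-style bit-packed buffer in LSB order: row i maps to bit i % 8 of byte i // 8. A pure-Python sketch of that write (set_result_bit is a hypothetical helper, for illustration only):

def set_result_bit(output: bytearray, row_idx: int, value: bool) -> None:
    """Set or clear the bit for row `row_idx` in an LSB-ordered bitmap."""
    byte_offset, bit_offset = divmod(row_idx, 8)
    mask = 1 << bit_offset
    if value:
        output[byte_offset] |= mask          # set only this row's bit
    else:
        output[byte_offset] &= 0xFF & ~mask  # clear only this row's bit

out = bytearray(1)
set_result_bit(out, 3, True)
assert out[0] == 0b00001000
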


@njit
def _text_contains_case_sensitive_nulls(
length: int,
valid_bits: np.ndarray,
valid_offset: int,
offsets: np.ndarray,
data: np.ndarray,
pat: bytes,
output: np.ndarray,
) -> None:
for row_idx in range(length):
# Check whether the current entry is null.
byte_offset = (row_idx + valid_offset) // 8
bit_offset = (row_idx + valid_offset) % 8
mask = np.uint8(1 << bit_offset)
valid = valid_bits[byte_offset] & mask

# We don't need to set the result for nulls: the validity bitmap that
# the calling code attaches to the output already marks them as missing.
if not valid:
continue

str_len = offsets[row_idx + 1] - offsets[row_idx]

contains = False
# Try to find the pattern at each starting position
for str_idx in range(max(0, str_len - len(pat) + 1)):
pat_found = True
# Compare at the current position byte-by-byte
for pat_idx in range(len(pat)):
if data[offsets[row_idx] + str_idx + pat_idx] != pat[pat_idx]:
pat_found = False
break
if pat_found:
contains = True
break

# Write out the result into the bit-mask
byte_offset_result = row_idx // 8
bit_offset_result = row_idx % 8
mask_result = np.uint8(1 << bit_offset_result)
current = output[byte_offset_result]
if contains: # must be logical, not bit-wise as different bits may be flagged
output[byte_offset_result] = current | mask_result
else:
output[byte_offset_result] = current & ~mask_result


@njit
def _shift_unaligned_bitmap(
valid_bits: np.ndarray, valid_offset: int, length: int, output: np.ndarray
) -> None:
for i in range(length):
byte_offset = (i + valid_offset) // 8
bit_offset = (i + valid_offset) % 8
mask = np.uint8(1 << bit_offset)
valid = valid_bits[byte_offset] & mask

byte_offset_result = i // 8
bit_offset_result = i % 8
mask_result = np.uint8(1 << bit_offset_result)
current = output[byte_offset_result]
if valid:
output[byte_offset_result] = current | mask_result


def shift_unaligned_bitmap(
valid_buffer: pa.Buffer, offset: int, length: int
) -> pa.Buffer:
"""Shift an unaligned bitmap to be offsetted at 0."""
output_size = length // 8
if length % 8 > 0:
output_size += 1
output = np.zeros(output_size, dtype=np.uint8)

_shift_unaligned_bitmap(valid_buffer, offset, length, output)

return pa.py_buffer(output)
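
For example (hypothetical values): the LSB-ordered byte 0b10110110 viewed at bit offset 3 with length 5 holds the bits 0, 1, 1, 0, 1; re-packed to start at offset 0 this becomes 0b00010110. A small sketch, assuming the function as added above:

import numpy as np
import pyarrow as pa

source = pa.py_buffer(np.array([0b10110110], dtype=np.uint8).tobytes())
shifted = shift_unaligned_bitmap(source, 3, 5)
assert np.frombuffer(shifted, dtype=np.uint8)[0] == 0b00010110
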


@apply_per_chunk
def _text_contains_case_sensitive(data: pa.Array, pat: str) -> pa.Array:
"""
Check for each element in the data whether it contains the pattern ``pat``.
This implementation does basic byte-by-byte comparison and is independent
of any locales or encodings.
"""
# Convert to UTF-8 bytes
pat_bytes: bytes = pat.encode()

# Initialise boolean (bit-packed) output array.
output_size = len(data) // 8
if len(data) % 8 > 0:
output_size += 1
output = np.empty(output_size, dtype=np.uint8)
if len(data) % 8 > 0:
# Zero trailing bits
output[-1] = 0

offsets, data_buffer = _extract_string_buffers(data)

if data.null_count == 0:
valid_buffer = None
_text_contains_case_sensitive_nonnull(
len(data), offsets, data_buffer, pat_bytes, output
)
else:
valid = _buffer_to_view(data.buffers()[0])
_text_contains_case_sensitive_nulls(
len(data), valid, data.offset, offsets, data_buffer, pat_bytes, output
)
valid_buffer = data.buffers()[0].slice(data.offset // 8)
if data.offset % 8 != 0:
valid_buffer = shift_unaligned_bitmap(
valid_buffer, data.offset % 8, len(data)
)

return pa.Array.from_buffers(
pa.bool_(), len(data), [valid_buffer, pa.py_buffer(output)], data.null_count
)
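
Taken together, a minimal sketch of what the new kernel returns (names as added in this diff; the result is a bit-packed Arrow boolean array):

import pyarrow as pa
from fletcher.algorithms.string import _text_contains_case_sensitive

arr = pa.array(["aa", "Ab", None], type=pa.string())
result = _text_contains_case_sensitive(arr, "a")
assert result.to_pylist() == [True, False, None]  # case-sensitive match
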
36 changes: 30 additions & 6 deletions fletcher/string_array.py
@@ -1,12 +1,17 @@
-from typing import Optional
from typing import Optional, Union

import numpy as np
import pandas as pd
import pyarrow as pa

-from ._algorithms import _endswith, _startswith
from ._algorithms import _endswith, _startswith, all_true_like
from ._numba_compat import NumbaString, NumbaStringArray
-from .algorithms.string import _text_cat, _text_cat_chunked, _text_cat_chunked_mixed
from .algorithms.string import (
_text_cat,
_text_cat_chunked,
_text_cat_chunked_mixed,
_text_contains_case_sensitive,
)
from .base import FletcherBaseArray, FletcherChunkedArray, FletcherContinuousArray


@@ -56,12 +61,19 @@ def cat(self, others: Optional[FletcherBaseArray]) -> pd.Series:

def _call_str_accessor(self, func, *args, **kwargs) -> pd.Series:
pd_series = self.data.to_pandas()
-result = pa.array(getattr(pd_series.str, func)(*args, **kwargs).values)
return self._series_like(
pa.array(getattr(pd_series.str, func)(*args, **kwargs).values)
)

def _series_like(self, array: Union[pa.Array, pa.ChunkedArray]) -> pd.Series:
"""Return an Arrow result as a series with the same base classes as the input."""
return pd.Series(
-type(self.obj.values)(result), dtype=type(self.obj.dtype)(result.type)
type(self.obj.values)(array),
dtype=type(self.obj.dtype)(array.type),
index=self.obj.index,
)

-def contains(self, pat, case=True, regex=True):
def contains(self, pat: str, case: bool = True, regex: bool = True) -> pd.Series:
"""
Test if pattern or regex is contained within a string of a Series or Index.
@@ -90,6 +102,18 @@ def contains(self, pat, case=True, regex=True):
given pattern is contained within the string of each element
of the Series or Index.
"""
if not regex:
if len(pat) == 0:
# For an empty pattern return all-True array
return self._series_like(all_true_like(self.data))

if case:
# Can just check for a match on the byte-sequence
return self._series_like(_text_contains_case_sensitive(self.data, pat))
else:
# Check if pat is all-ascii, then use lookup-table for lowercasing
# else: use libutf8proc
pass
return self._call_str_accessor("contains", pat=pat, case=case, regex=regex)
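
From the user side, the fast path is taken whenever regex=False and case=True. A usage sketch (the accessor is registered as fr_text in this commit's tests and as text in its benchmarks):

import pandas as pd
import pyarrow as pa
import fletcher as fr

ser = pd.Series(fr.FletcherChunkedArray(pa.array(["aa", "Ab", None])))
# regex=False, case=True dispatches to the Numba kernel added above
ser.fr_text.contains("a", regex=False)  # expected: True, False, null
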

def zfill(self, width: int) -> pd.Series:
32 changes: 26 additions & 6 deletions tests/test_text.py
@@ -17,16 +17,16 @@
([], ""),
(["a", "b"], ""),
(["aa", "ab", "ba"], "a"),
(["aa", "ab", "ba", None], "a"),
(["aa", "ab", "ba", None], "A"),
(["aa", "ab", "bA", None], "a"),
(["aa", "AB", "ba", None], "A"),
(["aa", "ab", "ba", "bb", None], "a"),
(["aa", "ab", "ba", "bb", None], "A"),
(["aa", "ab", "bA", "bB", None], "a"),
(["aa", "AB", "ba", "BB", None], "A"),
],
)


-def _fr_series_from_data(data, fletcher_variant):
-arrow_data = pa.array(data, type=pa.string())
def _fr_series_from_data(data, fletcher_variant, dtype=pa.string()):
arrow_data = pa.array(data, type=dtype)
if fletcher_variant == "chunked":
fr_array = fr.FletcherChunkedArray(arrow_data)
else:
@@ -82,6 +82,26 @@ def test_contains_no_regex(data, pat, fletcher_variant):
_check_str_to_bool("contains", data, fletcher_variant, pat=pat, regex=False)


@pytest.mark.parametrize(
"data, pat, expected",
[
([], "", []),
(["a", "b"], "", [True, True]),
(["aa", "Ab", "ba", "bb", None], "a", [True, False, True, False, None]),
],
)
def test_contains_no_regex_ascii(data, pat, expected, fletcher_variant):
fr_series = _fr_series_from_data(data, fletcher_variant)
fr_expected = _fr_series_from_data(expected, fletcher_variant, pa.bool_())

# Run over slices to check offset handling code
for i in range(len(data)):
ser = fr_series.tail(len(data) - i)
expected = fr_expected.tail(len(data) - i)
result = ser.fr_text.contains(pat, regex=False)
tm.assert_series_equal(result, expected)
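
The tail() calls above produce pyarrow arrays with non-zero offsets, which is exactly what exercises the unaligned-bitmap handling added in this commit. A quick illustration of the mechanism (hypothetical values):

import pyarrow as pa

arr = pa.array(["aa", "Ab", "ba", None])
sliced = arr.slice(1)  # a view into the same buffers, at a non-zero offset
assert sliced.offset == 1
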


@string_patterns
def test_contains_no_regex_ignore_case(data, pat, fletcher_variant):
_check_str_to_bool(
