Commit
Flush tokenizer cache when necessary (explosion#4258)
Flush tokenizer cache when affixes, token_match, or special cases are
modified.

Fixes explosion#4238, same issue as in explosion#1250.
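
For illustration, a minimal sketch of the behaviour being fixed (not part of the commit; assumes a spaCy v2.x install, and the "dont" special case and suffix regex are made-up examples). Previously, strings already stored in the tokenizer cache kept their old analysis even after new special cases or affix rules were added; with this change, add_special_case and the affix/token_match setters flush the stale entries so the next call picks up the new rules.

# Minimal sketch of the fixed behaviour (assumes spaCy v2.x; the "dont"
# special case and the suffix regex below are made-up illustrations).
import re

from spacy.lang.en import English
from spacy.symbols import ORTH

nlp = English()
text = "I dont understand."

# Tokenize once so the affected strings land in the tokenizer's cache.
print([t.text for t in nlp(text)])    # [..., 'dont', ...]

# Add a special case afterwards; add_special_case() now flushes the cache,
# so the new rule applies on the next call instead of being shadowed by a
# stale cache entry.
nlp.tokenizer.add_special_case("dont", [{ORTH: "do"}, {ORTH: "nt"}])
print([t.text for t in nlp(text)])    # [..., 'do', 'nt', ...]

# Assigning new affix rules (or token_match) likewise flushes the cache.
suffix_re = re.compile(r"[.,!?]$")
nlp.tokenizer.suffix_search = suffix_re.search
print([t.text for t in nlp(text)])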
adrianeboyd authored and honnibal committed Sep 8, 2019
1 parent d03401f commit 3780e2f
Showing 3 changed files with 63 additions and 5 deletions.
1 change: 0 additions & 1 deletion spacy/tests/regression/test_issue1001-1500.py
@@ -13,7 +13,6 @@
from spacy.symbols import ORTH, LEMMA, POS, VERB, VerbForm_part


@pytest.mark.xfail
def test_issue1061():
'''Test special-case works after tokenizing. Was caching problem.'''
text = 'I like _MATH_ even _MATH_ when _MATH_, except when _MATH_ is _MATH_! but not _MATH_.'
8 changes: 4 additions & 4 deletions spacy/tokenizer.pxd
@@ -16,10 +16,10 @@ cdef class Tokenizer:
cdef PreshMap _specials
cpdef readonly Vocab vocab

cdef public object token_match
cdef public object prefix_search
cdef public object suffix_search
cdef public object infix_finditer
cdef object _token_match
cdef object _prefix_search
cdef object _suffix_search
cdef object _infix_finditer
cdef object _rules

cpdef Doc tokens_from_list(self, list strings)
59 changes: 59 additions & 0 deletions spacy/tokenizer.pyx
@@ -61,6 +61,38 @@ cdef class Tokenizer:
for chunk, substrings in sorted(rules.items()):
self.add_special_case(chunk, substrings)

property token_match:
def __get__(self):
return self._token_match

def __set__(self, token_match):
self._token_match = token_match
self._flush_cache()

property prefix_search:
def __get__(self):
return self._prefix_search

def __set__(self, prefix_search):
self._prefix_search = prefix_search
self._flush_cache()

property suffix_search:
def __get__(self):
return self._suffix_search

def __set__(self, suffix_search):
self._suffix_search = suffix_search
self._flush_cache()

property infix_finditer:
def __get__(self):
return self._infix_finditer

def __set__(self, infix_finditer):
self._infix_finditer = infix_finditer
self._flush_cache()

def __reduce__(self):
args = (self.vocab,
self._rules,
@@ -141,9 +173,23 @@ cdef class Tokenizer:
for text in texts:
yield self(text)

def _flush_cache(self):
self._reset_cache([key for key in self._cache if not key in self._specials])

def _reset_cache(self, keys):
for k in keys:
del self._cache[k]
if not k in self._specials:
cached = <_Cached*>self._cache.get(k)
if cached is not NULL:
self.mem.free(cached)

def _reset_specials(self):
for k in self._specials:
cached = <_Cached*>self._specials.get(k)
del self._specials[k]
if cached is not NULL:
self.mem.free(cached)

cdef int _try_cache(self, hash_t key, Doc tokens) except -1:
cached = <_Cached*>self._cache.get(key)
@@ -183,6 +229,9 @@ cdef class Tokenizer:
while string and len(string) != last_size:
if self.token_match and self.token_match(string):
break
if self._specials.get(hash_string(string)) != NULL:
has_special[0] = 1
break
last_size = len(string)
pre_len = self.find_prefix(string)
if pre_len != 0:
@@ -360,8 +409,15 @@ cdef class Tokenizer:
cached.is_lex = False
cached.data.tokens = self.vocab.make_fused_token(substrings)
key = hash_string(string)
stale_special = <_Cached*>self._specials.get(key)
stale_cached = <_Cached*>self._cache.get(key)
self._flush_cache()
self._specials.set(key, cached)
self._cache.set(key, cached)
if stale_special is not NULL:
self.mem.free(stale_special)
if stale_special != stale_cached and stale_cached is not NULL:
self.mem.free(stale_cached)
self._rules[string] = substrings

def to_disk(self, path, **kwargs):
@@ -444,7 +500,10 @@ cdef class Tokenizer:
if data.get("rules"):
# make sure to hard reset the cache to remove data from the default exceptions
self._rules = {}
self._reset_cache([key for key in self._cache])
self._reset_specials()
self._cache = PreshMap()
self._specials = PreshMap()
for string, substrings in data.get("rules", {}).items():
self.add_special_case(string, substrings)

