Skip to content

Commit

Permalink
Replace function registries with catalogue (explosion#4584)
Browse files Browse the repository at this point in the history
* Replace function registries with catalogue

* Update __init__.py

* Fix test

* Revert unrelated flag [ci skip]
  • Loading branch information
ines authored Nov 7, 2019
1 parent 0f8678c commit 09cec3e
Show file tree
Hide file tree
Showing 11 changed files with 56 additions and 140 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ blis>=0.4.0,<0.5.0
murmurhash>=0.28.0,<1.1.0
wasabi>=0.4.0,<1.1.0
srsly>=0.1.0,<1.1.0
catalogue>=0.0.7,<1.1.0
# Third party dependencies
numpy>=1.15.0
requests>=2.13.0,<3.0.0
plac>=0.9.6,<1.2.0
pathlib==1.0.1; python_version < "3.4"
importlib_metadata>=0.20; python_version < "3.8"
# Optional dependencies
jsonschema>=2.6.0,<3.1.0
# Development dependencies
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ install_requires =
blis>=0.4.0,<0.5.0
wasabi>=0.4.0,<1.1.0
srsly>=0.1.0,<1.1.0
catalogue>=0.0.7,<1.1.0
# Third-party dependencies
setuptools
numpy>=1.15.0
plac>=0.9.6,<1.2.0
requests>=2.13.0,<3.0.0
pathlib==1.0.1; python_version < "3.4"
importlib_metadata>=0.20; python_version < "3.8"

[options.extras_require]
lookups =
Expand Down
2 changes: 1 addition & 1 deletion spacy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .about import __version__
from .errors import Errors, Warnings, deprecation_warning
from . import util
from .util import register_architecture, get_architecture
from .util import registry
from .language import component


Expand Down
5 changes: 0 additions & 5 deletions spacy/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,6 @@
except ImportError:
cupy = None

try: # Python 3.8
import importlib.metadata as importlib_metadata
except ImportError:
import importlib_metadata # noqa: F401

try:
from thinc.neural.optimizers import Optimizer # noqa: F401
except ImportError:
Expand Down
4 changes: 2 additions & 2 deletions spacy/displacy/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from .templates import TPL_DEP_SVG, TPL_DEP_WORDS, TPL_DEP_ARCS, TPL_ENTS
from .templates import TPL_ENT, TPL_ENT_RTL, TPL_FIGURE, TPL_TITLE, TPL_PAGE
from ..util import minify_html, escape_html, get_entry_points, ENTRY_POINTS
from ..util import minify_html, escape_html, registry
from ..errors import Errors


Expand Down Expand Up @@ -242,7 +242,7 @@ def __init__(self, options={}):
"CARDINAL": "#e4e7d2",
"PERCENT": "#e4e7d2",
}
user_colors = get_entry_points(ENTRY_POINTS.displacy_colors)
user_colors = registry.displacy_colors.get_all()
for user_color in user_colors.values():
colors.update(user_color)
colors.update(options.get("colors", {}))
Expand Down
6 changes: 3 additions & 3 deletions spacy/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def create_lookups(cls, nlp=None):
filenames = {name: root / filename for name, filename in cls.resources}
if LANG in cls.lex_attr_getters:
lang = cls.lex_attr_getters[LANG](None)
user_lookups = util.get_entry_point(util.ENTRY_POINTS.lookups, lang, {})
filenames.update(user_lookups)
if lang in util.registry.lookups:
filenames.update(util.registry.lookups.get(lang))
lookups = Lookups()
for name, filename in filenames.items():
data = util.load_language_data(filename)
Expand Down Expand Up @@ -155,7 +155,7 @@ def __init__(
100,000 characters in one text.
RETURNS (Language): The newly constructed object.
"""
user_factories = util.get_entry_points(util.ENTRY_POINTS.factories)
user_factories = util.registry.factories.get_all()
self.factories.update(user_factories)
self._meta = dict(meta)
self._path = None
Expand Down
6 changes: 3 additions & 3 deletions spacy/ml/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
from thinc.api import chain
from thinc.v2v import Maxout
from thinc.misc import LayerNorm
from ..util import register_architecture, make_layer
from ..util import registry, make_layer


@register_architecture("thinc.FeedForward.v1")
@registry.architectures.register("thinc.FeedForward.v1")
def FeedForward(config):
layers = [make_layer(layer_cfg) for layer_cfg in config["layers"]]
model = chain(*layers)
model.cfg = config
return model


@register_architecture("spacy.LayerNormalizedMaxout.v1")
@registry.architectures.register("spacy.LayerNormalizedMaxout.v1")
def LayerNormalizedMaxout(config):
width = config["width"]
pieces = config["pieces"]
Expand Down
18 changes: 9 additions & 9 deletions spacy/ml/tok2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from thinc.i2v import HashEmbed, StaticVectors
from thinc.t2t import ExtractWindow
from thinc.misc import Residual, LayerNorm, FeatureExtracter
from ..util import make_layer, register_architecture
from ..util import make_layer, registry
from ._wire import concatenate_lists


@register_architecture("spacy.Tok2Vec.v1")
@registry.architectures.register("spacy.Tok2Vec.v1")
def Tok2Vec(config):
doc2feats = make_layer(config["@doc2feats"])
embed = make_layer(config["@embed"])
Expand All @@ -24,13 +24,13 @@ def Tok2Vec(config):
return tok2vec


@register_architecture("spacy.Doc2Feats.v1")
@registry.architectures.register("spacy.Doc2Feats.v1")
def Doc2Feats(config):
columns = config["columns"]
return FeatureExtracter(columns)


@register_architecture("spacy.MultiHashEmbed.v1")
@registry.architectures.register("spacy.MultiHashEmbed.v1")
def MultiHashEmbed(config):
# For backwards compatibility with models before the architecture registry,
# we have to be careful to get exactly the same model structure. One subtle
Expand Down Expand Up @@ -78,7 +78,7 @@ def MultiHashEmbed(config):
return layer


@register_architecture("spacy.CharacterEmbed.v1")
@registry.architectures.register("spacy.CharacterEmbed.v1")
def CharacterEmbed(config):
from .. import _ml

Expand All @@ -94,7 +94,7 @@ def CharacterEmbed(config):
return model


@register_architecture("spacy.MaxoutWindowEncoder.v1")
@registry.architectures.register("spacy.MaxoutWindowEncoder.v1")
def MaxoutWindowEncoder(config):
nO = config["width"]
nW = config["window_size"]
Expand All @@ -110,7 +110,7 @@ def MaxoutWindowEncoder(config):
return model


@register_architecture("spacy.MishWindowEncoder.v1")
@registry.architectures.register("spacy.MishWindowEncoder.v1")
def MishWindowEncoder(config):
from thinc.v2v import Mish

Expand All @@ -124,12 +124,12 @@ def MishWindowEncoder(config):
return model


@register_architecture("spacy.PretrainedVectors.v1")
@registry.architectures.register("spacy.PretrainedVectors.v1")
def PretrainedVectors(config):
return StaticVectors(config["vectors_name"], config["width"], config["column"])


@register_architecture("spacy.TorchBiLSTMEncoder.v1")
@registry.architectures.register("spacy.TorchBiLSTMEncoder.v1")
def TorchBiLSTMEncoder(config):
import torch.nn
from thinc.extra.wrappers import PyTorchWrapperRNN
Expand Down
19 changes: 19 additions & 0 deletions spacy/tests/test_architectures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# coding: utf8
from __future__ import unicode_literals

import pytest
from spacy import registry
from thinc.v2v import Affine
from catalogue import RegistryError


@registry.architectures.register("my_test_function")
def create_model(nr_in, nr_out):
return Affine(nr_in, nr_out)


def test_get_architecture():
arch = registry.architectures.get("my_test_function")
assert arch is create_model
with pytest.raises(RegistryError):
registry.architectures.get("not_an_existing_key")
19 changes: 0 additions & 19 deletions spacy/tests/test_register_architecture.py

This file was deleted.

113 changes: 17 additions & 96 deletions spacy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import itertools
import numpy.random
import srsly
import catalogue
import sys

try:
Expand All @@ -27,29 +28,20 @@

from .symbols import ORTH
from .compat import cupy, CudaStream, path2str, basestring_, unicode_
from .compat import import_file, importlib_metadata
from .compat import import_file
from .errors import Errors, Warnings, deprecation_warning


LANGUAGES = {}
ARCHITECTURES = {}
_data_path = Path(__file__).parent / "data"
_PRINT_ENV = False


# NB: Only ever call this once! If called more than once within the
# function, test_issue1506 hangs and it's not 100% clear why.
AVAILABLE_ENTRY_POINTS = importlib_metadata.entry_points()


class ENTRY_POINTS(object):
"""Available entry points to register extensions."""

factories = "spacy_factories"
languages = "spacy_languages"
displacy_colors = "spacy_displacy_colors"
lookups = "spacy_lookups"
architectures = "spacy_architectures"
class registry(object):
languages = catalogue.create("spacy", "languages", entry_points=True)
architectures = catalogue.create("spacy", "architectures", entry_points=True)
lookups = catalogue.create("spacy", "lookups", entry_points=True)
factories = catalogue.create("spacy", "factories", entry_points=True)
displacy_colors = catalogue.create("spacy", "displacy_colors", entry_points=True)


def set_env_log(value):
Expand All @@ -65,8 +57,7 @@ def lang_class_is_loaded(lang):
lang (unicode): Two-letter language code, e.g. 'en'.
RETURNS (bool): Whether a Language class has been loaded.
"""
global LANGUAGES
return lang in LANGUAGES
return lang in registry.languages


def get_lang_class(lang):
Expand All @@ -75,19 +66,16 @@ def get_lang_class(lang):
lang (unicode): Two-letter language code, e.g. 'en'.
RETURNS (Language): Language class.
"""
global LANGUAGES
# Check if an entry point is exposed for the language code
entry_point = get_entry_point(ENTRY_POINTS.languages, lang)
if entry_point is not None:
LANGUAGES[lang] = entry_point
return entry_point
if lang not in LANGUAGES:
# Check if language is registered / entry point is available
if lang in registry.languages:
return registry.languages.get(lang)
else:
try:
module = importlib.import_module(".lang.%s" % lang, "spacy")
except ImportError as err:
raise ImportError(Errors.E048.format(lang=lang, err=err))
LANGUAGES[lang] = getattr(module, module.__all__[0])
return LANGUAGES[lang]
set_lang_class(lang, getattr(module, module.__all__[0]))
return registry.languages.get(lang)


def set_lang_class(name, cls):
Expand All @@ -96,8 +84,7 @@ def set_lang_class(name, cls):
name (unicode): Name of Language class.
cls (Language): Language class.
"""
global LANGUAGES
LANGUAGES[name] = cls
registry.languages.register(name, func=cls)


def get_data_path(require_exists=True):
Expand All @@ -121,49 +108,11 @@ def set_data_path(path):
_data_path = ensure_path(path)


def register_architecture(name, arch=None):
"""Decorator to register an architecture. An architecture is a function
that returns a Thinc Model object.
name (unicode): The name of the architecture to register.
arch (Model): Optional architecture if function is called directly and
not used as a decorator.
RETURNS (callable): Function to register architecture.
"""
global ARCHITECTURES
if arch is not None:
ARCHITECTURES[name] = arch
return arch

def do_registration(arch):
ARCHITECTURES[name] = arch
return arch

return do_registration


def make_layer(arch_config):
arch_func = get_architecture(arch_config["arch"])
arch_func = registry.architectures.get(arch_config["arch"])
return arch_func(arch_config["config"])


def get_architecture(name):
"""Get a model architecture function by name. Raises a KeyError if the
architecture is not found.
name (unicode): The name of the architecture.
RETURNS (Model): The architecture.
"""
# Check if an entry point is exposed for the architecture code
entry_point = get_entry_point(ENTRY_POINTS.architectures, name)
if entry_point is not None:
ARCHITECTURES[name] = entry_point
if name not in ARCHITECTURES:
names = ", ".join(sorted(ARCHITECTURES.keys()))
raise KeyError(Errors.E174.format(name=name, names=names))
return ARCHITECTURES[name]


def ensure_path(path):
"""Ensure string is converted to a Path.
Expand Down Expand Up @@ -327,34 +276,6 @@ def get_package_path(name):
return Path(pkg.__file__).parent


def get_entry_points(key):
"""Get registered entry points from other packages for a given key, e.g.
'spacy_factories' and return them as a dictionary, keyed by name.
key (unicode): Entry point name.
RETURNS (dict): Entry points, keyed by name.
"""
result = {}
for entry_point in AVAILABLE_ENTRY_POINTS.get(key, []):
result[entry_point.name] = entry_point.load()
return result


def get_entry_point(key, value, default=None):
"""Check if registered entry point is available for a given name and
load it. Otherwise, return None.
key (unicode): Entry point name.
value (unicode): Name of entry point to load.
default: Optional default value to return.
RETURNS: The loaded entry point or None.
"""
for entry_point in AVAILABLE_ENTRY_POINTS.get(key, []):
if entry_point.name == value:
return entry_point.load()
return default


def is_in_jupyter():
"""Check if user is running spaCy from a Jupyter notebook by detecting the
IPython kernel. Mainly used for the displaCy visualizer.
Expand Down

0 comments on commit 09cec3e

Please sign in to comment.