diff --git a/piecash/sa_extra.py b/piecash/sa_extra.py index 49f148f..a120822 100644 --- a/piecash/sa_extra.py +++ b/piecash/sa_extra.py @@ -18,19 +18,44 @@ ) from sqlalchemy.dialects import sqlite from sqlalchemy.ext.compiler import compiles -from sqlalchemy.ext.declarative import as_declarative from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import sessionmaker, object_session -# import yaml +try: + # sqlalchemy 1.4 and greater changes `as_declarative` to create a + # `registry` object, on which `as_declarative_base` is called. + # + # For unclear reasons, the `constructor` keyword is not forwarded + # to the `registry` constructor in `as_declarative`, even though + # it's defined for `registry.__init__` and not allowed by + # `registry.as_declarative_base`. + # + # This redefinition of `as_declarative` passes the `constructor` keyword + # onto `registry`, just as is done in `declarative_base`. + # + # I am using the existence of `sqlalchemy.orm.registry` (new in SA 1.4) + # as the marker for whether `constructor` is supported in `as_declarative` + from sqlalchemy.orm import registry + from sqlalchemy.orm.decl_base import _declarative_constructor + + def as_declarative(**kw): + bind, metadata, class_registry, constructor = ( + kw.pop("bind", None), + kw.pop("metadata", None), + kw.pop("class_registry", None), + kw.pop("constructor", _declarative_constructor), + ) + + return registry(_bind=bind, metadata=metadata, class_registry=class_registry, constructor=constructor).as_declarative_base(**kw) + +except ImportError: + # `as_declarative` was under `sqlalchemy.ext.declarative` prior to 1.4 + from sqlalchemy.ext.declarative import as_declarative def __init__blocked(self, *args, **kwargs): - raise NotImplementedError( - "Objects of type {} cannot be created from scratch " - "(only read)".format(self.__class__.__name__) - ) + raise NotImplementedError("Objects of type {} cannot be created from scratch " "(only read)".format(self.__class__.__name__)) @as_declarative(constructor=__init__blocked) @@ -103,15 +128,9 @@ def load_dialect_impl(self, dialect): def process_bind_param(self, value, dialect): if value is not None: - assert isinstance( - value, datetime.datetime - ), "value {} is not of type datetime.datetime but type {}".format( - value, type(value) - ) + assert isinstance(value, datetime.datetime), "value {} is not of type datetime.datetime but type {}".format(value, type(value)) if value.microsecond != 0: - logging.warning( - "A datetime has been given with microseconds which are not saved in the database" - ) + logging.warning("A datetime has been given with microseconds which are not saved in the database") if not value.tzinfo: value = tz.localize(value) @@ -145,15 +164,11 @@ def process_bind_param(self, value, dialect): if value is not None: assert isinstance(value, datetime.date) and not isinstance( value, datetime.datetime - ), "value {} is not of type datetime.date but type {}".format( - value, type(value) - ) + ), "value {} is not of type datetime.date but type {}".format(value, type(value)) if self.neutral_time: result = datetime.datetime.combine(value, datetime.time(10, 59, 0)) else: - result = tz.localize( - datetime.datetime.combine(value, datetime.time(0, 0, 0)) - ).astimezone(utc) + result = tz.localize(datetime.datetime.combine(value, datetime.time(0, 0, 0))).astimezone(utc) return result.replace(tzinfo=None) def process_result_value(self, value, dialect): @@ -205,9 +220,7 @@ def expr(cls): ) -def pure_slot_property( - slot_name, slot_transform=lambda x: x, ignore_invalid_slot=False -): +def pure_slot_property(slot_name, slot_transform=lambda x: x, ignore_invalid_slot=False): """ Create a property (class must have slots) that maps to a slot @@ -323,11 +336,7 @@ def process_bind_param(self, value, dialect): return [k for k, v in self.choices.items() if v == value][0] except IndexError: # print("Value '{}' is not in [{}]".format(", ".join(self.choices.values()))) - raise ValueError( - "Value '{}' is not in choices [{}]".format( - value, ", ".join(self.choices.values()) - ) - ) + raise ValueError("Value '{}' is not in choices [{}]".format(value, ", ".join(self.choices.values()))) def process_result_value(self, value, dialect): return self.choices[value] diff --git a/setup.py b/setup.py index ebc6ec3..83eef3b 100644 --- a/setup.py +++ b/setup.py @@ -202,7 +202,7 @@ def _lint(): ## package dependencies install_requires = [ - "SQLAlchemy>=1.0, <1.4", + "SQLAlchemy>=1.0, <1.5", "SQLAlchemy-Utils!=0.36.8", "pytz", "tzlocal",