Skip to content

Commit

Permalink
Allow use of Schema hooks on OneOfSchema
Browse files Browse the repository at this point in the history
Rather than overriding the `dump()` and `load()` methods of the Schema
class, override `_serialize` and `_deserialize`, which are the
"concrete" steps in schema loading and dumping which handle loading or
dumping on a field-by-field basis.

This still uses load() and dump() methods of the type schemas being
used, but it happens between the various hooks which may run on the
OneOfSchema instance.

Add a test that checks that a `post_dump` hook to remove the `type`
field works.

The most significant downside of this change is that it makes use of
several private APIs within marshmallow. Not only are `_serialize` and
`_deserialize` private methods, but the error_store object which is
used here is also considered private (per marshmallow docs).

In order to better guarantee behavior near-identical to marshmallow,
several methods from marshmallow.utils have been copied in-tree here.

One notable API change here is that arbitrary keyword arguments are no
longer being passed from `OneOfSchema.load()` and `OneOfSchema.dump()`
down into the type schemas' load and dump methods. As a result, you
cannot specify a load or dump parameter here and expect it to take
effect.
With the switch to overriding `_serialize` and `_deserialize`, there
is no practical way to pass parameters like that.

closes marshmallow-code#4
  • Loading branch information
sirosen committed Nov 10, 2020
1 parent 287ddd8 commit 6f040d3
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 71 deletions.
155 changes: 84 additions & 71 deletions marshmallow_oneofschema/one_of_schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,28 @@
from marshmallow import Schema, ValidationError
from collections.abc import Mapping
import inspect

from marshmallow import Schema, ValidationError, RAISE


# these helpers copied from marshmallow.utils #


def is_generator(obj) -> bool:
"""Return True if ``obj`` is a generator"""
return inspect.isgeneratorfunction(obj) or inspect.isgenerator(obj)


def is_iterable_but_not_string(obj) -> bool:
"""Return True if ``obj`` is an iterable object that isn't a string."""
return (hasattr(obj, "__iter__") and not hasattr(obj, "strip")) or is_generator(obj)


def is_collection(obj) -> bool:
"""Return True if ``obj`` is a collection type, e.g list, tuple, queryset."""
return is_iterable_but_not_string(obj) and not isinstance(obj, Mapping)


# end of helpers copied from marshmallow.utils #


class OneOfSchema(Schema):
Expand Down Expand Up @@ -63,32 +87,16 @@ def get_obj_type(self, obj):
"""Returns name of object schema"""
return obj.__class__.__name__

def dump(self, obj, *, many=None, **kwargs):
errors = {}
result_data = []
result_errors = {}
many = self.many if many is None else bool(many)
if not many:
result = result_data = self._dump(obj, **kwargs)
else:
for idx, o in enumerate(obj):
try:
result = self._dump(o, **kwargs)
result_data.append(result)
except ValidationError as error:
result_errors[idx] = error.normalized_messages()
result_data.append(error.valid_data)

result = result_data
errors = result_errors

if not errors:
return result
else:
exc = ValidationError(errors, data=obj, valid_data=result)
raise exc

def _dump(self, obj, *, update_fields=True, **kwargs):
# override the `_serialize` method of Schema, rather than `dump`
# this requires that we interact with a private API of marshmallow, but
# `_serialize` is the step that happens between pre_dump and post_dump
# hooks, so by using this rather than `load()`, we get schema hooks to work
def _serialize(self, obj, *, many=False):
if many and obj is not None:
return [self._serialize(subdoc, many=False) for subdoc in obj]
return self._dump_type_schema(obj)

def _dump_type_schema(self, obj):
obj_type = self.get_obj_type(obj)
if not obj_type:
return (
Expand All @@ -104,46 +112,58 @@ def _dump(self, obj, *, update_fields=True, **kwargs):

schema.context.update(getattr(self, "context", {}))

result = schema.dump(obj, many=False, **kwargs)
result = schema.dump(obj, many=False)
if result is not None:
result[self.type_field] = obj_type
return result

def load(self, data, *, many=None, partial=None, unknown=None, **kwargs):
errors = {}
result_data = []
result_errors = {}
many = self.many if many is None else bool(many)
if partial is None:
partial = self.partial
if not many:
try:
result = result_data = self._load(
data, partial=partial, unknown=unknown, **kwargs
)
# result_data.append(result)
except ValidationError as error:
result_errors = error.normalized_messages()
result_data.append(error.valid_data)
else:
for idx, item in enumerate(data):
try:
result = self._load(item, partial=partial, **kwargs)
result_data.append(result)
except ValidationError as error:
result_errors[idx] = error.normalized_messages()
result_data.append(error.valid_data)

result = result_data
errors = result_errors

if not errors:
return result
else:
exc = ValidationError(errors, data=data, valid_data=result)
raise exc

def _load(self, data, *, partial=None, unknown=None, **kwargs):
# override the `_deserialize` method of Schema, rather than `load`
# this requires that we interact with a private API of marshmallow, but
# `_deserialize` is the step that happens between pre_load and validation
# hooks, so by using this rather than `load()`, we get schema hooks to work
def _deserialize(
self,
data,
*,
error_store,
many=False,
partial=False,
unknown=RAISE,
index=None,
):
index = index if self.opts.index_errors else None
# if many, check for non-collection data (error) or iterate and
# re-invoke `_deserialize` on each one with many=False
# this is paraphrased from marshmallow.Schema._deserialize
if many:
if not is_collection(data):
error_store.store_error([self.error_messages["type"]], index=index)
return []
else:
return [
self._deserialize(
subdoc,
error_store=error_store,
many=False,
partial=partial,
unknown=unknown,
index=idx,
)
for idx, subdoc in enumerate(data)
]
if not isinstance(data, Mapping):
error_store.store_error([self.error_messages["type"]], index=index)
return self.dict_class()

try:
result = self._load_type_schema(data, partial=partial, unknown=unknown)
except ValidationError as err:
error_store.store_error(err.messages, index=index)
result = err.valid_data

return result

def _load_type_schema(self, data, *, partial=None, unknown=None):
if not isinstance(data, dict):
raise ValidationError({"_schema": "Invalid data type: %s" % data})

Expand Down Expand Up @@ -173,11 +193,4 @@ def _load(self, data, *, partial=None, unknown=None, **kwargs):

schema.context.update(getattr(self, "context", {}))

return schema.load(data, many=False, partial=partial, unknown=unknown, **kwargs)

def validate(self, data, *, many=None, partial=None):
try:
self.load(data, many=many, partial=partial)
except ValidationError as ve:
return ve.messages
return {}
return schema.load(data, many=False, partial=partial, unknown=unknown)
23 changes: 23 additions & 0 deletions tests/test_one_of_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,29 @@ class TestSchema(OneOfSchema):
TestSchema(unknown="exclude").load({"type": "Bar", "bar": 123})
assert Nonlocal.data["type"] == "Bar"

def test_post_dump_remove_type_field(self):
# test using a @post_dump hook to remove the type field which
# OneOfSchema will add to the data by default

# define a schema without post_dump
class MySchemaVariant1(OneOfSchema):
type_schemas = {"Foo": FooSchema, "Bar": BarSchema}

# and a variant with post_dump
class MySchemaVariant2(MySchemaVariant1):
@m.post_dump
def remove_type_field(self, data, **kwargs):
del data["type"]
return data

# sanity check: `type` should be present in a dump from Variant1
assert MySchemaVariant1().dump(Foo("someval")) == {
"type": "Foo",
"value": "someval",
}
# now check that the post_dump hook fired
assert MySchemaVariant2().dump(Foo("someval")) == {"value": "someval"}

def test_load_non_dict(self):
with pytest.raises(m.ValidationError) as exc_info:
MySchema().load(123)
Expand Down

0 comments on commit 6f040d3

Please sign in to comment.