diff --git a/aio_pika/patterns/rpc.py b/aio_pika/patterns/rpc.py index bbec383e..2baad719 100644 --- a/aio_pika/patterns/rpc.py +++ b/aio_pika/patterns/rpc.py @@ -399,18 +399,36 @@ async def unregister(self, func): self.routes.pop(queue.name) +class JsonRPCError(Exception): + def __init__(self, exc_type, message=None, *args): + super(JsonRPCError, self).__init__(message, args) + self.exc_type = exc_type + + class JsonRPC(RPC): SERIALIZER = json CONTENT_TYPE = "application/json" def serialize(self, data: Any) -> bytes: - return self.SERIALIZER.dumps(data, ensure_ascii=False, default=repr) + return self.SERIALIZER.dumps( + data, ensure_ascii=False, default=repr).encode('ascii') + + def deserialize(self, data: Any) -> bytes: + res = super().deserialize(data) + if isinstance(res, dict) and "error" in res: + res = JsonRPCError(res['error']['type'], + res['error']['message'], + res['error']['args']) + return res def serialize_exception(self, exception: Exception) -> bytes: return self.serialize( { "error": { - "type": exception.__class__.__name__, + "type": f'{exception.__module__}.' + f'{exception.__class__.__name__}' + if hasattr(exception, '__module__') + else exception.__class__.__name__, "message": repr(exception), "args": exception.args, }, diff --git a/tests/test_rpc.py b/tests/test_rpc.py index 8582489e..caf5dd3c 100644 --- a/tests/test_rpc.py +++ b/tests/test_rpc.py @@ -2,12 +2,12 @@ import logging import pytest - +import inspect import aio_pika from aio_pika import Message from aio_pika.exceptions import DeliveryError from aio_pika.message import IncomingMessage -from aio_pika.patterns.rpc import RPC +from aio_pika.patterns.rpc import RPC, JsonRPC, JsonRPCError from aio_pika.patterns.rpc import log as rpc_logger from tests import get_random_name @@ -19,6 +19,14 @@ def rpc_func(*, foo, bar): return {"foo": "bar"} +class CustomException(Exception): + pass + + +def rpc_raise_exception(*, foo, bar): + raise CustomException('foo bar') + + class TestCase: async def test_simple(self, channel: aio_pika.Channel): rpc = await RPC.create(channel, auto_delete=True) @@ -172,3 +180,45 @@ async def test_register_twice(self, channel: aio_pika.Channel): await rpc.unregister(rpc_func) await rpc.close() + + async def test_jsonrpc_simple(self, channel: aio_pika.Channel): + rpc = await JsonRPC.create(channel, auto_delete=True) + + await rpc.register("test.rpc", rpc_func, auto_delete=True) + + result = await rpc.proxy.test.rpc(foo=None, bar=None) + assert result == {"foo": "bar"} + + await rpc.unregister(rpc_func) + await rpc.close() + + # Close already closed + await rpc.close() + + async def test_jsonrpc_assert(self, channel: aio_pika.Channel): + rpc = await JsonRPC.create(channel, auto_delete=True) + + await rpc.register("test.rpc", rpc_func, auto_delete=True) + + with pytest.raises(JsonRPCError) as excinfo: + await rpc.proxy.test.rpc(foo=True, bar=None) + assert excinfo.value.exc_type == 'AssertionError' + + await rpc.unregister(rpc_func) + await rpc.close() + + async def test_jsonrpc_error(self, channel: aio_pika.Channel): + rpc = await JsonRPC.create(channel, auto_delete=True) + + await rpc.register("test.rpc_error", rpc_raise_exception, + auto_delete=True) + + with pytest.raises(Exception) as excinfo: + await rpc.proxy.test.rpc_error(foo=True, bar=None) + assert excinfo.value.exc_type == 'tests.test_rpc.CustomException' + + with pytest.raises(JsonRPCError): + await rpc.proxy.test.rpc_error(foo=True, bar=None) + + await rpc.unregister(rpc_raise_exception) + await rpc.close()