Skip to content

Commit

Permalink
implemented api and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dantownsend committed Feb 25, 2020
1 parent a3a1a6a commit 21d8505
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 28 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ async def main():
)

# Check if a particular exception was raised.
ValueError in response.exceptions
ValueError in response.exception_types
# >>> True

# To get all exceptions:
print(response.exceptions)
# >>> [ValueError()]

# To get all instances of a particular exception:
response.exceptions.get_all(ValueError)
response.exceptions_of_type(ValueError)
# >>> [ValueError()]

# To get the number of exceptions:
Expand All @@ -56,11 +56,11 @@ async def main():
try:
# To combines all of the exceptions into a single one, which merges the
# messages.
raise response.compound_exception
raise response.compound_exception()
except CompoundException as compound_exception:
print("Caught it")

if ValueError in compound_exception:
if ValueError in compound_exception.exception_types:
print("Caught a ValueError")

```
131 changes: 119 additions & 12 deletions asyncio_tools.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,134 @@
from __future__ import annotations
import asyncio

from functools import cached_property
import typing as t


class CompoundException(Exception):
"""
Is used to aggregate several exceptions into a single exception, with a
combined message. It contains a reference to the constituent exceptions.
"""

def __init__(self, exceptions: t.List[Exception]):
self.exceptions = exceptions

def __str__(self):
return (
f"CompoundException, {len(self.exceptions)} errors ["
+ "; ".join(
[
f"{i.__class__.__name__}: {i.__str__()}"
for i in self.exceptions
]
)
+ "]"
)

@cached_property
def exception_types(self) -> t.List[t.Type[Exception]]:
"""
Returns the constituent exception types.
Useful for checks like this:
if TransactionError in compound_exception.exception_types:
some_transaction_cleanup()
"""
return [i.__class__ for i in self.exceptions]


class GatheredResults:

__slots__ = ("results", "exceptions")
# __dict__ is required for cached_property
__slots__ = ("results", "__dict__")

def __init__(self, results):
def __init__(self, results: t.List[t.Any]):
self.results = results
self.exceptions = [
i for i in results if isinstance(i, Exception)
]

def has_exception(self, exception_class: Exception):
for i in self.exceptions:
if isinstance(i, exception_class):
return True
###########################################################################

def __setattr__(self, key, value):
"""
Since we use cached_properties for most of the lookups, we don't want
the underlying results to be changed. There should be no reason for a
user to want to change the results, but just to be sure we raise a
ValueError.
"""
if key == "results":
raise ValueError("results is immutable")
super().__setattr__(key, value)

@property
def all(self) -> t.List[t.Any]:
"""
Just a proxy.
"""
return self.results

###########################################################################

@cached_property
def exceptions(self) -> t.List[t.Type[Exception]]:
"""
Returns all exception instances which were returned by asyncio.gather.
"""
return [i for i in self.results if isinstance(i, Exception)]

def exceptions_of_type(
self, exception_type: t.Type[Exception]
) -> t.List[t.Type[Exception]]:
"""
Returns any exceptions of the given type.
"""
return [i for i in self.exceptions if isinstance(i, exception_type)]

@cached_property
def exception_types(self) -> t.List[t.Type[Exception]]:
"""
Returns the exception types which appeared in the response.
"""
return [i.__class__ for i in self.exceptions]

@cached_property
def exception_count(self) -> int:
return len(self.exceptions)

###########################################################################

@cached_property
def successes(self) -> t.List[t.Any]:
"""
Returns all values in the response which aren't exceptions.
"""
return [i for i in self.results if not isinstance(i, Exception)]

@cached_property
def success_count(self) -> int:
return len(self.successes)

###########################################################################

def compound_exception(self) -> t.Optional[CompoundException]:
"""
Create a single exception which combines all of the exceptions.
A function instead of a property to leave room for some extra args
in the future.
raise gathered_response.compound_exception()
"""
if not self.exceptions:
return False

def __contains__(self, exception_class: Exception):
return self.has_exception(exception_class)
return CompoundException(self.exceptions)


async def gather(*coroutines: t.Sequence[t.Coroutine]) -> GatheredResults:
"""
A wrapper on top of asyncio.gather which makes handling the results
easier.
"""
results = await asyncio.gather(*coroutines, return_exceptions=True)
return GatheredResults(results)
3 changes: 3 additions & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
black==19.10b0
twine==3.1.1

Empty file added py.typed
Empty file.
37 changes: 25 additions & 12 deletions tests/test_gathered_results.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,40 @@
import asyncio
from unittest import TestCase

from asyncio_tools import gather
from asyncio_tools import gather, CompoundException, GatheredResults


async def good():
return True
return "OK"


async def bad():
raise Exception()
raise ValueError("Bad value")


class TestGatheredResults(TestCase):

async def contains(self):
results = await gather(
good(),
bad(),
good(),
def test_exceptions(self):
response: GatheredResults = asyncio.run(gather(good(), bad(), good()))
self.assertTrue(ValueError in response.exception_types)
self.assertTrue(response.exception_count == 1)

def test_successes(self):
response: GatheredResults = asyncio.run(gather(good(), bad(), good()))
self.assertTrue(response.successes == ["OK", "OK"])
self.assertTrue(response.success_count == 2)

def test_compound_exception(self):
response: GatheredResults = asyncio.run(
gather(good(), bad(), good(), bad())
)

self.assertTrue(Exception in results)
with self.assertRaises(CompoundException):
raise response.compound_exception()

exception = response.compound_exception()
self.assertTrue(ValueError in exception.exception_types)

def test_contains(self):
asyncio.run(self.contains())
def test_set(self):
results = GatheredResults([])
with self.assertRaises(ValueError):
results.results = None

0 comments on commit 21d8505

Please sign in to comment.