-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a3a1a6a
commit 21d8505
Showing
5 changed files
with
151 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,134 @@ | ||
from __future__ import annotations | ||
import asyncio | ||
|
||
from functools import cached_property | ||
import typing as t | ||
|
||
|
||
class CompoundException(Exception): | ||
""" | ||
Is used to aggregate several exceptions into a single exception, with a | ||
combined message. It contains a reference to the constituent exceptions. | ||
""" | ||
|
||
def __init__(self, exceptions: t.List[Exception]): | ||
self.exceptions = exceptions | ||
|
||
def __str__(self): | ||
return ( | ||
f"CompoundException, {len(self.exceptions)} errors [" | ||
+ "; ".join( | ||
[ | ||
f"{i.__class__.__name__}: {i.__str__()}" | ||
for i in self.exceptions | ||
] | ||
) | ||
+ "]" | ||
) | ||
|
||
@cached_property | ||
def exception_types(self) -> t.List[t.Type[Exception]]: | ||
""" | ||
Returns the constituent exception types. | ||
Useful for checks like this: | ||
if TransactionError in compound_exception.exception_types: | ||
some_transaction_cleanup() | ||
""" | ||
return [i.__class__ for i in self.exceptions] | ||
|
||
|
||
class GatheredResults: | ||
|
||
__slots__ = ("results", "exceptions") | ||
# __dict__ is required for cached_property | ||
__slots__ = ("results", "__dict__") | ||
|
||
def __init__(self, results): | ||
def __init__(self, results: t.List[t.Any]): | ||
self.results = results | ||
self.exceptions = [ | ||
i for i in results if isinstance(i, Exception) | ||
] | ||
|
||
def has_exception(self, exception_class: Exception): | ||
for i in self.exceptions: | ||
if isinstance(i, exception_class): | ||
return True | ||
########################################################################### | ||
|
||
def __setattr__(self, key, value): | ||
""" | ||
Since we use cached_properties for most of the lookups, we don't want | ||
the underlying results to be changed. There should be no reason for a | ||
user to want to change the results, but just to be sure we raise a | ||
ValueError. | ||
""" | ||
if key == "results": | ||
raise ValueError("results is immutable") | ||
super().__setattr__(key, value) | ||
|
||
@property | ||
def all(self) -> t.List[t.Any]: | ||
""" | ||
Just a proxy. | ||
""" | ||
return self.results | ||
|
||
########################################################################### | ||
|
||
@cached_property | ||
def exceptions(self) -> t.List[t.Type[Exception]]: | ||
""" | ||
Returns all exception instances which were returned by asyncio.gather. | ||
""" | ||
return [i for i in self.results if isinstance(i, Exception)] | ||
|
||
def exceptions_of_type( | ||
self, exception_type: t.Type[Exception] | ||
) -> t.List[t.Type[Exception]]: | ||
""" | ||
Returns any exceptions of the given type. | ||
""" | ||
return [i for i in self.exceptions if isinstance(i, exception_type)] | ||
|
||
@cached_property | ||
def exception_types(self) -> t.List[t.Type[Exception]]: | ||
""" | ||
Returns the exception types which appeared in the response. | ||
""" | ||
return [i.__class__ for i in self.exceptions] | ||
|
||
@cached_property | ||
def exception_count(self) -> int: | ||
return len(self.exceptions) | ||
|
||
########################################################################### | ||
|
||
@cached_property | ||
def successes(self) -> t.List[t.Any]: | ||
""" | ||
Returns all values in the response which aren't exceptions. | ||
""" | ||
return [i for i in self.results if not isinstance(i, Exception)] | ||
|
||
@cached_property | ||
def success_count(self) -> int: | ||
return len(self.successes) | ||
|
||
########################################################################### | ||
|
||
def compound_exception(self) -> t.Optional[CompoundException]: | ||
""" | ||
Create a single exception which combines all of the exceptions. | ||
A function instead of a property to leave room for some extra args | ||
in the future. | ||
raise gathered_response.compound_exception() | ||
""" | ||
if not self.exceptions: | ||
return False | ||
|
||
def __contains__(self, exception_class: Exception): | ||
return self.has_exception(exception_class) | ||
return CompoundException(self.exceptions) | ||
|
||
|
||
async def gather(*coroutines: t.Sequence[t.Coroutine]) -> GatheredResults: | ||
""" | ||
A wrapper on top of asyncio.gather which makes handling the results | ||
easier. | ||
""" | ||
results = await asyncio.gather(*coroutines, return_exceptions=True) | ||
return GatheredResults(results) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
black==19.10b0 | ||
twine==3.1.1 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,40 @@ | ||
import asyncio | ||
from unittest import TestCase | ||
|
||
from asyncio_tools import gather | ||
from asyncio_tools import gather, CompoundException, GatheredResults | ||
|
||
|
||
async def good(): | ||
return True | ||
return "OK" | ||
|
||
|
||
async def bad(): | ||
raise Exception() | ||
raise ValueError("Bad value") | ||
|
||
|
||
class TestGatheredResults(TestCase): | ||
|
||
async def contains(self): | ||
results = await gather( | ||
good(), | ||
bad(), | ||
good(), | ||
def test_exceptions(self): | ||
response: GatheredResults = asyncio.run(gather(good(), bad(), good())) | ||
self.assertTrue(ValueError in response.exception_types) | ||
self.assertTrue(response.exception_count == 1) | ||
|
||
def test_successes(self): | ||
response: GatheredResults = asyncio.run(gather(good(), bad(), good())) | ||
self.assertTrue(response.successes == ["OK", "OK"]) | ||
self.assertTrue(response.success_count == 2) | ||
|
||
def test_compound_exception(self): | ||
response: GatheredResults = asyncio.run( | ||
gather(good(), bad(), good(), bad()) | ||
) | ||
|
||
self.assertTrue(Exception in results) | ||
with self.assertRaises(CompoundException): | ||
raise response.compound_exception() | ||
|
||
exception = response.compound_exception() | ||
self.assertTrue(ValueError in exception.exception_types) | ||
|
||
def test_contains(self): | ||
asyncio.run(self.contains()) | ||
def test_set(self): | ||
results = GatheredResults([]) | ||
with self.assertRaises(ValueError): | ||
results.results = None |