Skip to content

Commit

Permalink
add "force sync" escape hatch
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz committed Nov 1, 2024
1 parent 85f77af commit fbd033f
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 45 deletions.
5 changes: 4 additions & 1 deletion src/prefect/_internal/compatibility/async_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __call__(
...

aio: Callable[P, Coroutine[Any, Any, R]]
sync: Callable[P, R]


def is_in_async_context() -> bool:
Expand All @@ -44,6 +45,7 @@ def async_dispatch(
- Return a coroutine when in an async context (detected via running event loop)
- Run synchronously when in a sync context
- Provide .aio for explicit async access
- Provide .sync for explicit sync access
Args:
async_impl: The async implementation to dispatch to when async execution
Expand All @@ -63,8 +65,9 @@ def wrapper(
return async_impl(*args, **kwargs)
return sync_fn(*args, **kwargs)

# Attach the async implementation directly
# Attach both async and sync implementations directly
wrapper.aio = async_impl
wrapper.sync = sync_fn
return wrapper # type: ignore

return decorator
137 changes: 93 additions & 44 deletions tests/_internal/compatibility/test_async_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,76 +9,125 @@
from prefect.utilities.asyncutils import run_sync_in_worker_thread


def test_async_compatible_fn_in_sync_context():
data = []
class TestAsyncDispatchBasicUsage:
def test_async_compatible_fn_in_sync_context(self):
data = []

async def my_function_async():
data.append("async")

async def my_function_async():
data.append("async")
@async_dispatch(my_function_async)
def my_function():
data.append("sync")

my_function()
assert data == ["sync"]

async def test_async_compatible_fn_in_async_context(self):
data = []

async def my_function_async():
data.append("async")

@async_dispatch(my_function_async)
def my_function():
data.append("sync")
@async_dispatch(my_function_async)
def my_function():
data.append("sync")

await my_function()
assert data == ["async"]


class TestAsyncDispatchExplicitUsage:
async def test_async_compatible_fn_explicit_async_usage(self):
"""Verify .aio property works as expected"""
data = []

async def my_function_async():
data.append("async")

@async_dispatch(my_function_async)
def my_function():
data.append("sync")

my_function()
assert data == ["sync"]
await my_function.aio()
assert data == ["async"]

def test_async_compatible_fn_explicit_async_usage_with_asyncio_run(self):
"""Verify .aio property works as expected with asyncio.run"""
data = []

async def test_async_compatible_fn_in_async_context():
data = []
async def my_function_async():
data.append("async")

async def my_function_async():
data.append("async")
@async_dispatch(my_function_async)
def my_function():
data.append("sync")

@async_dispatch(my_function_async)
def my_function():
data.append("sync")
asyncio.run(my_function.aio())
assert data == ["async"]

await my_function()
assert data == ["async"]
async def test_async_compatible_fn_explicit_sync_usage(self):
"""Verify .sync property works as expected in async context"""
data = []

async def my_function_async():
data.append("async")

async def test_async_compatible_fn_explicit_async_usage():
"""Verify .aio property works as expected"""
data = []
@async_dispatch(my_function_async)
def my_function():
data.append("sync")

async def my_function_async():
data.append("async")
# Even though we're in async context, .sync should force sync execution
my_function.sync()
assert data == ["sync"]

@async_dispatch(my_function_async)
def my_function():
data.append("sync")
def test_async_compatible_fn_explicit_sync_usage_in_sync_context(self):
"""Verify .sync property works as expected in sync context"""
data = []

await my_function.aio()
assert data == ["async"]
async def my_function_async():
data.append("async")

@async_dispatch(my_function_async)
def my_function():
data.append("sync")

def test_async_compatible_fn_explicit_async_usage_with_asyncio_run():
"""Verify .aio property works as expected with asyncio.run"""
data = []
my_function.sync()
assert data == ["sync"]

async def my_function_async():
data.append("async")

@async_dispatch(my_function_async)
def my_function():
data.append("sync")
class TestAsyncDispatchValidation:
def test_async_compatible_requires_async_implementation(self):
"""Verify we properly reject non-async implementations"""

asyncio.run(my_function.aio())
assert data == ["async"]
def not_async():
pass

with pytest.raises(TypeError, match="async_impl must be an async function"):

def test_async_compatible_requires_async_implementation():
"""Verify we properly reject non-async implementations"""
@async_dispatch(not_async)
def my_function():
pass

def not_async():
pass
async def test_async_compatible_fn_attributes_exist(self):
"""Verify both .sync and .aio attributes are present"""

with pytest.raises(TypeError, match="async_impl must be an async function"):
async def my_function_async():
pass

@async_dispatch(not_async)
@async_dispatch(my_function_async)
def my_function():
pass

assert hasattr(my_function, "sync"), "Should have .sync attribute"
assert hasattr(my_function, "aio"), "Should have .aio attribute"
assert (
my_function.sync is my_function.__wrapped__
), "Should reference original sync function"
assert (
my_function.aio is my_function_async
), "Should reference original async function"


class TestAsyncCompatibleFnCannotBeUsedWithAsyncioRun:
def test_async_compatible_fn_in_sync_context_errors_with_asyncio_run(self):
Expand Down

0 comments on commit fbd033f

Please sign in to comment.