diff --git a/tests/test_gen_single.py b/tests/test_gen_single.py index 744fa05..ca4d638 100644 --- a/tests/test_gen_single.py +++ b/tests/test_gen_single.py @@ -28,11 +28,19 @@ def test_non_async_syntax_error(): _parse_code("foo = None\n bar = None", src_file=SRC_FILE, compiler_flags=COMPILER_FLAGS) +def test_no_async_roundtrip(): + code = "None" + assert not _needs_async_handling(code, src_file=SRC_FILE, compiler_flags=COMPILER_FLAGS) + + code_unparsed = ast.unparse(_parse_code(code, src_file=SRC_FILE, compiler_flags=COMPILER_FLAGS)) + + assert code_unparsed == code + + @needs_ast_unparse @pytest.mark.parametrize( - ("code", "needs"), + "code", [ - pytest.param("None", False, id="no_async"), pytest.param( dedent( """ @@ -42,7 +50,6 @@ async def afn(): assert await afn() """ ), - True, id="await", ), pytest.param( @@ -55,7 +62,6 @@ async def agen(): assert item """ ), - True, id="async_for", ), pytest.param( @@ -67,7 +73,6 @@ async def agen(): assert [item async for item in agen()] == [True] """ ), - True, id="async_comprehension", ), pytest.param( @@ -83,23 +88,21 @@ async def acm(): assert ctx """ ), - True, id="async_context_manager", ), ], ) -def test_async_handling(code, needs): - assert _needs_async_handling(code, src_file=SRC_FILE, compiler_flags=COMPILER_FLAGS) is needs +def test_async_handling(code): + assert _needs_async_handling(code, src_file=SRC_FILE, compiler_flags=COMPILER_FLAGS) # Since AST objects are quite involved to compare, we unparse again and check that nothing has changed. Note that # since we are dealing with AST and not CST here, all whitespace is eliminated in the process and this needs to be # reflected in the input as well. code_stripped = "\n".join(line for line in code.splitlines() if line) code_unparsed = ast.unparse(_parse_code(code, src_file=SRC_FILE, compiler_flags=COMPILER_FLAGS)) - assert (code_unparsed == code_stripped) ^ needs + assert code_unparsed != code_stripped - if needs: - assert not _needs_async_handling(code_unparsed, src_file=SRC_FILE, compiler_flags=COMPILER_FLAGS) + assert not _needs_async_handling(code_unparsed, src_file=SRC_FILE, compiler_flags=COMPILER_FLAGS) exec(COMPILER(code_unparsed, SRC_FILE, "exec"), make_globals())