From 2353e9f122df76d3e533874997c030b5a1f3a9e9 Mon Sep 17 00:00:00 2001 From: Milos Gligoric Date: Sun, 7 Jan 2024 12:58:25 -0600 Subject: [PATCH] Add support for the pass statement in trivial cases (empty body) --- pk | 14 +++++++++++--- pykokkos/core/cppast/serializer.py | 2 +- tests/test_AST_translator.py | 7 +++++++ 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/pk b/pk index bbc14000..67374d97 100755 --- a/pk +++ b/pk @@ -48,14 +48,21 @@ function pk_docker_pull() { function _pk_cmd() { local -r name="${1}" + [ $# -lt 1 ] && return 1 + shift 1 + + [ -z "${name}" ] && \ + { echo "no name provided"; return 1; } ( cd "${CHOME}/pykokkos" export PYTHONPATH=pykokkos:$PYTHONPATH - python "${name}" ) + python "${name}" "$@" ) } function pk_example() { local -r name="${1}" + [ $# -lt 1 ] && return 1 + shift 1 [ -z "${name}" ] && \ { echo "no name provided (e.g., examples/kokkos-tutorials/workload/01.py)"; return 1; } @@ -67,11 +74,12 @@ function pk_example() { --volume $(pwd):"${CHOME}/pykokkos" \ --user pk:$(id -g) \ "${PROJECT}" \ - "${CHOME}/pykokkos/pk" "_pk_cmd" "${name}" + "${CHOME}/pykokkos/pk" "_pk_cmd" "${name}" \ + "$@" } function pk_tests() { - pk_example "runtests.py" + pk_example "runtests.py" "$@" } "$@" diff --git a/pykokkos/core/cppast/serializer.py b/pykokkos/core/cppast/serializer.py index bc539fad..96251592 100644 --- a/pykokkos/core/cppast/serializer.py +++ b/pykokkos/core/cppast/serializer.py @@ -300,7 +300,7 @@ def serialize_DeclStmt(self, node: DeclStmt) -> str: return self.serialize(node.decl) + ";" def serialize_EmptyStmt(self, node: EmptyStmt) -> str: - return "" + return "{}" def serialize_ForStmt(self, node: ForStmt) -> str: init: str = self.serialize(node.init) diff --git a/tests/test_AST_translator.py b/tests/test_AST_translator.py index 3a2f9154..c8ce17f7 100644 --- a/tests/test_AST_translator.py +++ b/tests/test_AST_translator.py @@ -110,6 +110,10 @@ def while_stmt(self, tid: int, acc: pk.Acc[pk.double]) -> None: acc += self.i_2 x += 1 + @pk.workunit + def pass_stmt(self, tid: int) -> None: + pass + @pk.workunit def call(self, tid: int, acc: pk.Acc[pk.double]) -> None: pk.printf("Testing printf: %d\n", self.i_1) @@ -269,6 +273,9 @@ def test_while_stmt(self): self.assertEqual(expected_result, result) + def test_pass(self): + pk.parallel_for(self.range_policy, self.functor.pass_stmt) + def test_call(self): expected_result: int = self.threads * abs(- self.i_1) result: int = pk.parallel_reduce(self.range_policy, self.functor.call)