diff --git a/include/fwd.hpp b/include/fwd.hpp index ee6bb31..15c1b8e 100644 --- a/include/fwd.hpp +++ b/include/fwd.hpp @@ -134,6 +134,7 @@ enum KokkosViewDataType { Uint64, Float32, Float64, + Bool, ViewDataTypesEnd }; diff --git a/include/traits.hpp b/include/traits.hpp index 83f9d1b..7d9e773 100644 --- a/include/traits.hpp +++ b/include/traits.hpp @@ -85,6 +85,7 @@ VIEW_DATA_TYPE(uint32_t, Uint32, "uint32", "unsigned", "unsigned_int") VIEW_DATA_TYPE(uint64_t, Uint64, "uint64", "unsigned_long") VIEW_DATA_TYPE(float, Float32, "float32", "float") VIEW_DATA_TYPE(double, Float64, "float64", "double") +VIEW_DATA_TYPE(bool, Bool, "bool", "bool_") //----------------------------------------------------------------------------// // diff --git a/kokkos/__init__.py.in b/kokkos/__init__.py.in index 1db0e0e..e3b3e57 100644 --- a/kokkos/__init__.py.in +++ b/kokkos/__init__.py.in @@ -166,6 +166,7 @@ try: "unsigned_long", "float", "double", + "bool", "Serial", # devices "Threads", "OpenMP", diff --git a/kokkos/test/views.py b/kokkos/test/views.py index d86aad2..10b82a5 100644 --- a/kokkos/test/views.py +++ b/kokkos/test/views.py @@ -140,7 +140,10 @@ def test_view_access(self): self.assertEqual(_data[0].create_mirror_view()[_zeros], 0) self.assertEqual(_data[1].create_mirror_view()[_zeros], 0) self.assertEqual(_data[0].create_mirror_view()[_idx], 1) - self.assertEqual(_data[1].create_mirror_view()[_idx], 2) + if _kwargs["dtype"] == kokkos.bool: + self.assertEqual(_data[1].create_mirror_view()[_idx], True) + else: + self.assertEqual(_data[1].create_mirror_view()[_idx], 2) def test_view_iadd(self): """view_iadd""" @@ -169,8 +172,13 @@ def test_view_iadd(self): self.assertEqual(_data[0].create_mirror_view()[_zeros], 0) self.assertEqual(_data[1].create_mirror_view()[_zeros], 0) - self.assertEqual(_data[0].create_mirror_view()[_idx], 4) - self.assertEqual(_data[1].create_mirror_view()[_idx], 5) + if _kwargs["dtype"] == kokkos.bool: + # positive values are simply True + self.assertEqual(_data[0].create_mirror_view()[_idx], True) + self.assertEqual(_data[1].create_mirror_view()[_idx], True) + else: + self.assertEqual(_data[0].create_mirror_view()[_idx], 4) + self.assertEqual(_data[1].create_mirror_view()[_idx], 5) def test_view_isub(self): """view_isub""" @@ -199,8 +207,13 @@ def test_view_isub(self): self.assertEqual(_data[0].create_mirror_view()[_zeros], 0) self.assertEqual(_data[1].create_mirror_view()[_zeros], 0) - self.assertEqual(_data[0].create_mirror_view()[_idx], 7) - self.assertEqual(_data[1].create_mirror_view()[_idx], 17) + if _kwargs["dtype"] == kokkos.bool: + # positive values are simply True + self.assertEqual(_data[0].create_mirror_view()[_idx], True) + self.assertEqual(_data[1].create_mirror_view()[_idx], True) + else: + self.assertEqual(_data[0].create_mirror_view()[_idx], 7) + self.assertEqual(_data[1].create_mirror_view()[_idx], 17) def test_view_imul(self): """view_imul""" @@ -229,8 +242,13 @@ def test_view_imul(self): self.assertEqual(_data[0].create_mirror_view()[_zeros], 0) self.assertEqual(_data[1].create_mirror_view()[_zeros], 0) - self.assertEqual(_data[0].create_mirror_view()[_idx], 3) - self.assertEqual(_data[1].create_mirror_view()[_idx], 6) + if _kwargs["dtype"] == kokkos.bool: + # positive values are simply True + self.assertEqual(_data[0].create_mirror_view()[_idx], True) + self.assertEqual(_data[1].create_mirror_view()[_idx], True) + else: + self.assertEqual(_data[0].create_mirror_view()[_idx], 3) + self.assertEqual(_data[1].create_mirror_view()[_idx], 6) # def test_view_create_mirror(self): @@ -261,8 +279,12 @@ def test_view_create_mirror(self): self.assertEqual(_data[0].create_mirror_view()[_zeros], 0) self.assertEqual(_data[1].create_mirror_view()[_zeros], 0) - self.assertEqual(_data[0].create_mirror_view()[_idx], 3) - self.assertEqual(_data[1].create_mirror_view()[_idx], 6) + if _kwargs["dtype"] == kokkos.bool: + self.assertEqual(_data[0].create_mirror_view()[_idx], True) + self.assertEqual(_data[1].create_mirror_view()[_idx], True) + else: + self.assertEqual(_data[0].create_mirror_view()[_idx], 3) + self.assertEqual(_data[1].create_mirror_view()[_idx], 6) _mirror_data = [ kokkos.create_mirror(_data[0], copy=True), @@ -271,8 +293,12 @@ def test_view_create_mirror(self): self.assertEqual(_mirror_data[0][_zeros], 0) self.assertEqual(_mirror_data[1][_zeros], 0) - self.assertEqual(_mirror_data[0][_idx], 3) - self.assertEqual(_mirror_data[1][_idx], 6) + if _kwargs["dtype"] == kokkos.bool: + self.assertEqual(_mirror_data[0][_idx], True) + self.assertEqual(_mirror_data[1][_idx], True) + else: + self.assertEqual(_mirror_data[0][_idx], 3) + self.assertEqual(_mirror_data[1][_idx], 6) _mirror_data = [ kokkos.create_mirror(_data[0], copy=False), @@ -312,8 +338,13 @@ def test_view_create_mirror_view(self): self.assertEqual(_data[0].create_mirror_view()[_zeros], 0) self.assertEqual(_data[1].create_mirror_view()[_zeros], 0) - self.assertEqual(_data[0].create_mirror_view()[_idx], 3) - self.assertEqual(_data[1].create_mirror_view()[_idx], 6) + if _kwargs["dtype"] == kokkos.bool: + # positive values are simply True + self.assertEqual(_data[0].create_mirror_view()[_idx], True) + self.assertEqual(_data[1].create_mirror_view()[_idx], True) + else: + self.assertEqual(_data[0].create_mirror_view()[_idx], 3) + self.assertEqual(_data[1].create_mirror_view()[_idx], 6) _mirror_data = [ kokkos.create_mirror_view(_data[0], copy=True), @@ -322,8 +353,13 @@ def test_view_create_mirror_view(self): self.assertEqual(_mirror_data[0][_zeros], 0) self.assertEqual(_mirror_data[1][_zeros], 0) - self.assertEqual(_mirror_data[0][_idx], 3) - self.assertEqual(_mirror_data[1][_idx], 6) + if _kwargs["dtype"] == kokkos.bool: + # positive values are simply True + self.assertEqual(_mirror_data[0][_idx], True) + self.assertEqual(_mirror_data[1][_idx], True) + else: + self.assertEqual(_mirror_data[0][_idx], 3) + self.assertEqual(_mirror_data[1][_idx], 6) _mirror_data = [ kokkos.create_mirror_view(_data[0], copy=False), @@ -334,12 +370,18 @@ def test_view_create_mirror_view(self): self.assertEqual(_mirror_data[1][_zeros], 0) if kokkos.get_host_accessible(_data[0].space): - self.assertEqual(_mirror_data[0][_idx], 3) + if _kwargs["dtype"] == kokkos.bool: + self.assertEqual(_mirror_data[0][_idx], True) + else: + self.assertEqual(_mirror_data[0][_idx], 3) else: self.assertNotEqual(_mirror_data[0][_idx], 3) if kokkos.get_host_accessible(_data[1].space): - self.assertEqual(_mirror_data[1][_idx], 6) + if _kwargs["dtype"] == kokkos.bool: + self.assertEqual(_mirror_data[1][_idx], True) + else: + self.assertEqual(_mirror_data[1][_idx], 6) else: self.assertNotEqual(_mirror_data[1][_idx], 6) @@ -374,8 +416,13 @@ def test_view_deep_copy(self): self.assertEqual(_data[0].create_mirror_view()[_zeros], 0) self.assertEqual(_data[1].create_mirror_view()[_zeros], 0) - self.assertEqual(_data[0].create_mirror_view()[_idx], 3) - self.assertEqual(_data[1].create_mirror_view()[_idx], 6) + if _kwargs["dtype"] == kokkos.bool: + # positive values are simply True + self.assertEqual(_data[0].create_mirror_view()[_idx], True) + self.assertEqual(_data[1].create_mirror_view()[_idx], True) + else: + self.assertEqual(_data[0].create_mirror_view()[_idx], 3) + self.assertEqual(_data[1].create_mirror_view()[_idx], 6) _copied_data = conf.generate_variant(_shape, **_kwargs) kokkos.deep_copy(_copied_data[0], _data[0]) @@ -383,8 +430,13 @@ def test_view_deep_copy(self): self.assertEqual(_copied_data[0].create_mirror_view()[_zeros], 0) self.assertEqual(_copied_data[1].create_mirror_view()[_zeros], 0) - self.assertEqual(_copied_data[0].create_mirror_view()[_idx], 3) - self.assertEqual(_copied_data[1].create_mirror_view()[_idx], 6) + if _kwargs["dtype"] == kokkos.bool: + # positive values are simply True + self.assertEqual(_copied_data[0].create_mirror_view()[_idx], True) + self.assertEqual(_copied_data[1].create_mirror_view()[_idx], True) + else: + self.assertEqual(_copied_data[0].create_mirror_view()[_idx], 3) + self.assertEqual(_copied_data[1].create_mirror_view()[_idx], 6) # main runner diff --git a/kokkos/utility.py b/kokkos/utility.py index 8bb2ee5..a1e357e 100644 --- a/kokkos/utility.py +++ b/kokkos/utility.py @@ -101,6 +101,8 @@ def read_dtype(_dtype): return lib.float32 elif _dtype == np.float64: return lib.float64 + elif _dtype == np.bool_: + return lib.bool except ImportError: pass diff --git a/src/variants/CMakeLists.txt b/src/variants/CMakeLists.txt index 955d462..57ff8f0 100644 --- a/src/variants/CMakeLists.txt +++ b/src/variants/CMakeLists.txt @@ -26,7 +26,7 @@ TARGET_LINK_LIBRARIES(libpykokkos-variants PUBLIC SET(_types concrete dynamic) SET(_variants layout memory_trait) -SET(_data_types Int8 Int16 Int32 Int64 Uint8 Uint16 Uint32 Uint64 Float32 Float64) +SET(_data_types Int8 Int16 Int32 Int64 Uint8 Uint16 Uint32 Uint64 Float32 Float64 Bool) SET(layout_enums Right) SET(memory_trait_enums Managed)