Skip to content

Commit

Permalink
ufuncs: update type checking in manipulate and util after changing vi…
Browse files Browse the repository at this point in the history
…ew.dtype type
  • Loading branch information
NaderAlAwar committed Dec 6, 2023
1 parent 41f8199 commit ffec823
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions pykokkos/lib/manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def ravel_C_impl_2d_double(tid: int, view: pk.View2D[pk.double], out: pk.View1D[

def ravel(view, order="C"):
if view.rank() == 2:
if str(view.dtype) == "DataType.double":
if view.dtype.__name__ == "float64":
out = pk.View([view.shape[0] * view.shape[1]], pk.double)
if order == "F":
pk.parallel_for(view.shape[1], ravel_F_impl_2d_double, view=view, out=out)
Expand Down Expand Up @@ -59,7 +59,7 @@ def expand_dims_1_impl_2d_double(tid: int, view: pk.View2D[pk.double], out: pk.V


def expand_dims(view, axis=0):
if str(view.dtype) != "DataType.double":
if view.dtype.__name__ == "float64":
raise RuntimeError("expand_dims supports views of type double only")

if view.rank() == 1:
Expand Down
4 changes: 2 additions & 2 deletions pykokkos/lib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def sum(viewA, axis=None):
return out


if str(viewA.dtype) == "DataType.double":
if viewA.dtype.__name__ == "float64":
return pk.parallel_reduce(
viewA.shape[0],
sum_impl_1d_double,
Expand Down Expand Up @@ -77,7 +77,7 @@ def col(view, col):
view_temp[0] = col
col = view_temp

if str(view.dtype) == "DataType.double":
if view.dtype.__name__ == "float64":
out = pk.View([view.shape[0]], pk.double)
pk.parallel_for(
view.shape[0],
Expand Down

0 comments on commit ffec823

Please sign in to comment.