Skip to content

Commit

Permalink
Add, Multiply, Divide now support float32 scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
HannanNaeem committed Sep 10, 2024
1 parent afce4de commit 43148f4
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 8 deletions.
26 changes: 18 additions & 8 deletions pykokkos/lib/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ def add_impl_1d_double(tid: int, viewA: pk.View1D[pk.double], viewB: pk.View1D[p

@pk.workunit
def add_impl_1d_float(tid: int, viewA: pk.View1D[pk.float], viewB: pk.View1D[pk.float], out: pk.View1D[pk.float]):
out[tid] = viewA[tid] + viewB[tid]
out[tid] = viewA[tid] + viewB[tid % viewB.extent(0)]

@pk.workunit
def add_impl_2d_1d(tid, viewA, viewB, out):
Expand Down Expand Up @@ -964,7 +964,7 @@ def multiply_impl_1d_double(tid: int, viewA: pk.View1D[pk.double], viewB: pk.Vie

@pk.workunit
def multiply_impl_1d_float(tid: int, viewA: pk.View1D[pk.float], viewB: pk.View1D[pk.float], out: pk.View1D[pk.float]):
out[tid] = viewA[tid] * viewB[tid]
out[tid] = viewA[tid] * viewB[tid % viewB.extent(0)]


@pk.workunit
Expand Down Expand Up @@ -998,8 +998,13 @@ def multiply(viewA, viewB, profiler_name: Optional[str] = None):
"""

# viewA must always be a view of type float 64 oe 32
if not isinstance(viewA, pk.ViewType) and viewA.dtype.__name__ not in ["float32", "float64"]:
raise RuntimeError("Incompatible first argument of type: {}, must be a float32 or float 64 Pykokkos view".format(viewA.dtype))

# then, if viewB is a scalar conform it to viewA's type
if not isinstance(viewB, pk.ViewType):
view_temp = pk.View([1], pk.double)
view_temp = pk.View([1], pk.double if viewA.dtype.__name__ == "float64" else pk.float32)
view_temp[0] = viewB
viewB = view_temp

Expand Down Expand Up @@ -1071,7 +1076,7 @@ def multiply(viewA, viewB, profiler_name: Optional[str] = None):
viewB=smaller,
out=out)
else:
raise RuntimeError("Incompatible Types")
raise RuntimeError("Incompatible Types {}, {}".format(viewA.dtype, viewB.dtype))
return out


Expand Down Expand Up @@ -1607,7 +1612,7 @@ def divide_impl_1d_double(tid: int, viewA: pk.View1D[pk.double], viewB: pk.View1

@pk.workunit
def divide_impl_1d_float(tid: int, viewA: pk.View1D[pk.float], viewB: pk.View1D[pk.float], out: pk.View1D[pk.float]):
out[tid] = viewA[tid] / viewB[tid]
out[tid] = viewA[tid] / viewB[tid % viewB.extent(0)]


@pk.workunit
Expand All @@ -1634,8 +1639,13 @@ def divide(viewA, viewB, profiler_name: Optional[str] = None):
Output view.
"""
if not isinstance(viewB, pk.ViewType) and not isinstance(viewB, pk.ViewType):
view_temp = pk.View([1], pk.double)
# viewA must always be a view of type float 64 oe 32
if not isinstance(viewA, pk.ViewType) and viewA.dtype.__name__ not in ["float32", "float64"]:
raise RuntimeError("Incompatible first argument of type: {}, must be a float32 or float 64 Pykokkos view".format(viewA.dtype))

# then, if viewB is a scalar conform it to viewA's type
if not isinstance(viewB, pk.ViewType):
view_temp = pk.View([1], pk.double if viewA.dtype.__name__ == "float64" else pk.float32)
view_temp[0] = viewB
viewB = view_temp

Expand Down Expand Up @@ -1669,7 +1679,7 @@ def divide(viewA, viewB, profiler_name: Optional[str] = None):
viewB=viewB,
out=out)
else:
raise RuntimeError("Incompatible Types")
raise RuntimeError("Incompatible Types {}, {}".format(viewA.dtype, viewB.dtype))
return out


Expand Down
23 changes: 23 additions & 0 deletions tests/test_ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,29 @@ def test_multi_array_1d_exposed_ufuncs_vs_numpy(pk_ufunc,

assert_allclose(actual, expected)

# TODO: There may be more funcs that support scalars
@pytest.mark.parametrize("pk_ufunc, numpy_ufunc", [
(pk.add, np.add),
(pk.multiply, np.multiply),
(pk.divide, np.divide)
])
@pytest.mark.parametrize("numpy_dtype", [
np.float64,
np.float32
])
def test_scalar_operations_vs_numpy(pk_ufunc,
numpy_ufunc,
numpy_dtype):
data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
expected = numpy_ufunc(
np.array(data, dtype=numpy_dtype),
1
)
actual = pk_ufunc(
pk.array(np.array(data, dtype=numpy_dtype)),
1
)
assert_allclose(actual, expected)

@pytest.mark.parametrize("pk_ufunc, numpy_ufunc", [
(pk.matmul, np.matmul),
Expand Down

0 comments on commit 43148f4

Please sign in to comment.