Skip to content

Commit

Permalink
Views: store cupy array in View object when creating View from cp.nda…
Browse files Browse the repository at this point in the history
…rray (#255)
  • Loading branch information
NaderAlAwar authored Feb 7, 2024
1 parent a5d323a commit 9664cb7
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions pykokkos/interface/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ def __init__(
space: MemorySpace = MemorySpace.MemorySpaceDefault,
layout: Layout = Layout.LayoutDefault,
trait: Trait = Trait.TraitDefault,
array: Optional[np.ndarray] = None
array: Optional[np.ndarray] = None,
cp_array = None
):
"""
View constructor.
Expand All @@ -241,9 +242,10 @@ def __init__(
:param layout: the layout of the view in memory.
:param trait: the memory trait of the view
:param array: the numpy array if trait is Unmanaged
:param cp_array: the cupy array if trait is Unmanaged
"""

self._init_view(shape, dtype, space, layout, trait, array)
self._init_view(shape, dtype, space, layout, trait, array, cp_array)

def resize(self, dimension: int, size: int) -> None:
"""
Expand Down Expand Up @@ -298,7 +300,8 @@ def _init_view(
space: MemorySpace = MemorySpace.MemorySpaceDefault,
layout: Layout = Layout.LayoutDefault,
trait: Trait = Trait.TraitDefault,
array: Optional[np.ndarray] = None
array: Optional[np.ndarray] = None,
cp_array = None
) -> None:
"""
Initialize the view
Expand All @@ -309,6 +312,7 @@ def _init_view(
:param layout: the layout of the view in memory.
:param trait: the memory trait of the view
:param array: the numpy array if trait is Unmanaged
:param cp_array: the cupy array if trait is Unmanaged
"""

self.shape: Tuple[int] = tuple(shape)
Expand Down Expand Up @@ -353,6 +357,12 @@ def _init_view(
# invalidate the data. Currently, this happens when
# calling asarray()
self.orig_array = array

if cp_array is not None:
self.xp_array = cp_array
else:
self.xp_array = array

else:
if len(self.shape) == 0:
shape = [1]
Expand Down Expand Up @@ -610,13 +620,14 @@ def __hash__(self):
hash_value = hash(self.array)
return hash_value

def from_numpy(array: np.ndarray, space: Optional[MemorySpace] = None, layout: Optional[Layout] = None) -> ViewType:
def from_numpy(array: np.ndarray, space: Optional[MemorySpace] = None, layout: Optional[Layout] = None, cp_array = None) -> ViewType:
"""
Create a PyKokkos View from a numpy array
:param array: the numpy array
:param space: an optional argument for memory space (used by from_array)
:param layout: an optional argument for layout (used by from_array)
:param cp_array: the original cupy array (used by from_array)
:returns: a PyKokkos View wrapping the array
"""

Expand Down Expand Up @@ -674,7 +685,7 @@ def from_numpy(array: np.ndarray, space: Optional[MemorySpace] = None, layout: O
else:
ret_list = list((array.shape))

return View(ret_list, dtype, space=space, trait=Trait.Unmanaged, array=array, layout=layout)
return View(ret_list, dtype, space=space, trait=Trait.Unmanaged, array=array, layout=layout, cp_array=cp_array)

def from_array(array) -> ViewType:
"""
Expand Down Expand Up @@ -731,7 +742,7 @@ def from_array(array) -> ViewType:
elif km.get_gpu_framework() is pk.HIP:
memory_space = MemorySpace.HIPSpace

return from_numpy(np_array, memory_space, layout)
return from_numpy(np_array, memory_space, layout, array)

def is_array(array) -> bool:
"""
Expand Down

0 comments on commit 9664cb7

Please sign in to comment.