[PinSAGESampler] support PinSAGE sampler on GPU (#3567)
* Feat: support API "randomwalk_topk" in library

* Feat: use the new API "randomwalk_topk" for PinSAGESampler

* Minor

* Minor

* Refactor: modify code as the checker requires

* Minor

* Minor

* Minor

* Minor

* Fix: checking errors in RandomWalkTopk

* Refactor: modified the docstring for randomwalk_topk

* change randomwalk_topk to internal

* fix

* rename

* Minor for pinsage.py

* Feat: support randomwalk and SelectPinSageNeighbors on GPU

Port the RandomWalk algorithm and SelectPinSageNeighbors to GPU.

* Feat: support GPU on python APIs

* Feat: remove perf print information in FrequencyHashmap

* Fix: modified the code format

Modified the code format as task_lint.sh suggested

* Feat: let test script support PinSAGESampler on GPU

Let the test script support PinSAGESampler on GPU,
plus a minor change to "restart_prob".

* Minor

* Minor

* Minor

* Refactor: use the atomic operations from the array module

* Minor: change the long lines

* Refactor: modified the get_node_types for gpu

* Feat: update the contributor date

* Perf: remove unnecessary stream sync

* Feat: support other random walks

Non-uniform choice is still not supported.

* Fix: add CUDA switch for random walk

Co-authored-by: Quan Gan <[email protected]>
lixiaobai09 and BarclayII authored Dec 15, 2021
1 parent 78e0dae commit dd762a1
Showing 9 changed files with 920 additions and 16 deletions.
1 change: 1 addition & 0 deletions cmake/modules/CUDA.cmake
@@ -247,6 +247,7 @@ macro(dgl_config_cuda out_variable)
src/runtime/cuda/*.cu
src/geometry/cuda/*.cu
src/graph/transform/cuda/*.cu
+ src/graph/sampling/randomwalks/*.cu
)

# NVCC flags
6 changes: 3 additions & 3 deletions python/dgl/sampling/pinsage.py
@@ -41,7 +41,7 @@ class RandomWalkNeighborSampler(object):
Parameters
----------
G : DGLGraph
- The graph. It must be on CPU.
+ The graph.
num_traversals : int
The maximum number of metapath-based traversals for a single random walk.
@@ -71,7 +71,6 @@ class RandomWalkNeighborSampler(object):
"""
def __init__(self, G, num_traversals, termination_prob,
num_random_walks, num_neighbors, metapath=None, weight_column='weights'):
- assert G.device == F.cpu(), "Graph must be on CPU."
self.G = G
self.weight_column = weight_column
self.num_random_walks = num_random_walks
@@ -93,7 +92,8 @@ def __init__(self, G, num_traversals, termination_prob,
self.full_metapath = metapath * num_traversals
restart_prob = np.zeros(self.metapath_hops * num_traversals)
restart_prob[self.metapath_hops::self.metapath_hops] = termination_prob
- self.restart_prob = F.zerocopy_from_numpy(restart_prob)
+ restart_prob = F.tensor(restart_prob, dtype=F.float32)
+ self.restart_prob = F.copy_to(restart_prob, G.device)

# pylint: disable=no-member
def __call__(self, seed_nodes):
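
The hunk above builds a per-hop termination schedule and then copies it to the graph's device. As a rough illustration of the schedule layout only, the following pure-Python sketch mirrors the slice assignment `restart_prob[self.metapath_hops::self.metapath_hops] = termination_prob` using plain lists. The helper name `build_restart_prob` is hypothetical and is not part of DGL; the device copy is not reproduced here.

```python
def build_restart_prob(metapath_hops, num_traversals, termination_prob):
    """Return one termination probability per hop of the repeated metapath.

    Hypothetical helper for illustration only; the commit builds the same
    layout with NumPy and then moves it to the graph's device.
    """
    probs = [0.0] * (metapath_hops * num_traversals)
    # Termination is only possible where the 2nd, 3rd, ... traversal begins,
    # i.e. at indices metapath_hops, 2 * metapath_hops, and so on.
    for i in range(metapath_hops, len(probs), metapath_hops):
        probs[i] = termination_prob
    return probs


schedule = build_restart_prob(metapath_hops=2, num_traversals=3, termination_prob=0.5)
print(schedule)  # [0.0, 0.0, 0.5, 0.0, 0.5, 0.0]
```

Index 0 stays at zero, so under this layout the first traversal of the metapath is never cut short.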
9 changes: 2 additions & 7 deletions python/dgl/sampling/randomwalks.py
@@ -33,11 +33,11 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
Parameters
----------
g : DGLGraph
- The graph. Must be on CPU.
+ The graph.
nodes : Tensor
Node ID tensor from which the random walk traces start.
- The tensor must be on CPU, and must have the same dtype as the ID type
+ The tensor must have the same dtype as the ID type
of the graph.
metapath : list[str or tuple of str], optional
Metapath, specified as a list of edge types.
@@ -85,10 +85,6 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
``(length + 1)``.
The type IDs match the ones in the original graph ``g``.
- Notes
- -----
- The returned tensors are on CPU.
Examples
--------
The following creates a homogeneous graph:
@@ -160,7 +156,6 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
[ 2, 0, 1, 1, 3, 2, 2],
[ 0, 1, 1, 3, 0, 0, 0]]), tensor([0, 0, 1, 0, 0, 1, 0]))
"""
- assert g.device == F.cpu(), "Graph must be on CPU."
n_etypes = len(g.canonical_etypes)
n_ntypes = len(g.ntypes)
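
For readers unfamiliar with what `restart_prob` does during a walk, here is a minimal CPU-only sketch in plain Python, not the CUDA kernel this commit adds. It assumes uniform neighbor choice, a per-step termination draw made before each hop, and `-1` padding for traces that stop early. The function name `random_walk_sketch` and the dict-based adjacency are hypothetical, and the exact order of checks inside DGL may differ.

```python
import random


def random_walk_sketch(adj, seed, length, restart_prob, rng):
    """Walk up to `length` steps from `seed`; pad with -1 after termination.

    adj          : dict mapping node -> list of successor nodes (assumed layout)
    restart_prob : list of `length` floats, termination probability per step
    rng          : random.Random instance
    """
    trace = [seed]
    current = seed
    for step in range(length):
        neighbors = adj.get(current, [])
        # Stop early on a dead end or when the termination draw succeeds.
        if not neighbors or rng.random() < restart_prob[step]:
            break
        current = rng.choice(neighbors)  # uniform choice among successors
        trace.append(current)
    # Pad so that every trace has exactly length + 1 entries.
    trace.extend([-1] * (length + 1 - len(trace)))
    return trace


cycle = {0: [1], 1: [2], 2: [0]}
rng = random.Random(0)
never_stop = random_walk_sketch(cycle, 0, 4, [0.0] * 4, rng)
always_stop = random_walk_sketch(cycle, 0, 4, [1.0] * 4, rng)
print(never_stop)   # [0, 1, 2, 0, 1]
print(always_stop)  # [0, -1, -1, -1, -1]
```

Each trace has `length + 1` entries, matching the shape described in the docstring above.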
