Skip to content

Commit

Permalink
[Bugs] Fix distributed example error and import error (#3783)
Browse files Browse the repository at this point in the history
* fix

* raise an error

* fix docserver crash
  • Loading branch information
BarclayII authored Feb 28, 2022
1 parent 7ab1034 commit 62e23bd
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 26 deletions.
7 changes: 1 addition & 6 deletions include/dgl/aten/array_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -356,25 +356,20 @@ IdArray VecToIdArray(const std::vector<T>& vec,
}

/*!
* \brief Get the context of the first non-null array, and check if the non-null arrays'
* \brief Get the context of the first array, and check if the non-null arrays'
* contexts are the same.
*
* Throws an error if all the arrays are null arrays.
*/
inline DLContext GetContextOf(const std::vector<IdArray>& arrays) {
bool first = true;
DLContext result;
for (auto& array : arrays) {
if (IsNullArray(array))
continue;
if (first) {
first = false;
result = array->ctx;
} else {
CHECK_EQ(array->ctx, result) << "Context of the input arrays are different";
}
}
CHECK(!first) << "All input arrays are empty.";
return result;
}

Expand Down
2 changes: 1 addition & 1 deletion python/dgl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from .data.utils import save_graphs, load_graphs
from . import optim
from .frame import LazyFeature
from .utils import recursive_apply
from .utils import apply_each

from ._deprecate.graph import DGLGraph as DGLGraphStale
from ._deprecate.nodeflow import *
9 changes: 7 additions & 2 deletions python/dgl/sampling/neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,12 +309,17 @@ def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
nodes = {g.ntypes[0] : nodes}

nodes = utils.prepare_tensor_dict(g, nodes, 'nodes')
if len(nodes) == 0:
raise ValueError(
"Got an empty dictionary in the nodes argument. "
"Please pass in a dictionary with empty tensors as values instead.")
ctx = utils.to_dgl_context(F.context(next(iter(nodes.values()))))
nodes_all_types = []
for ntype in g.ntypes:
if ntype in nodes:
nodes_all_types.append(F.to_dgl_nd(nodes[ntype]))
else:
nodes_all_types.append(nd.array([], ctx=nd.cpu()))
nodes_all_types.append(nd.array([], ctx=ctx))

if isinstance(fanout, nd.NDArray):
fanout_array = fanout
Expand Down Expand Up @@ -354,7 +359,7 @@ def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
if etype in exclude_edges:
excluded_edges_all_t.append(F.to_dgl_nd(exclude_edges[etype]))
else:
excluded_edges_all_t.append(nd.array([], ctx=nd.cpu()))
excluded_edges_all_t.append(nd.array([], ctx=ctx))

subgidx = _CAPI_DGLSampleNeighbors(g._graph, nodes_all_types, fanout_array,
edge_dir, prob_arrays, excluded_edges_all_t, replace)
Expand Down
23 changes: 12 additions & 11 deletions tutorials/multi/1_graph_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,17 +227,18 @@ def main(rank, world_size, dataset, seed=0):

###############################################################################
# Finally we load the dataset and launch the processes.
#

if __name__ == '__main__':
import torch.multiprocessing as mp

from dgl.data import GINDataset

num_gpus = 4
procs = []
dataset = GINDataset(name='IMDBBINARY', self_loop=False)
mp.spawn(main, args=(num_gpus, dataset), nprocs=num_gpus)
#
# .. code:: python
#
# if __name__ == '__main__':
# import torch.multiprocessing as mp
#
# from dgl.data import GINDataset
#
# num_gpus = 4
# procs = []
# dataset = GINDataset(name='IMDBBINARY', self_loop=False)
# mp.spawn(main, args=(num_gpus, dataset), nprocs=num_gpus)

# Thumbnail credits: DGL
# sphinx_gallery_thumbnail_path = '_static/blitz_5_graph_classification.png'
13 changes: 7 additions & 6 deletions tutorials/multi/2_node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,13 @@ def run(proc_id, devices):
# Python’s built-in ``multiprocessing`` except that it handles the
# subtleties between forking and multithreading in Python.
#

# Say you have four GPUs.
if __name__ == '__main__':
num_gpus = 4
import torch.multiprocessing as mp
mp.spawn(run, args=(list(range(num_gpus)),), nprocs=num_gpus)
# .. code:: python
#
# # Say you have four GPUs.
# if __name__ == '__main__':
# num_gpus = 4
# import torch.multiprocessing as mp
# mp.spawn(run, args=(list(range(num_gpus)),), nprocs=num_gpus)

# Thumbnail credits: Stanford CS224W Notes
# sphinx_gallery_thumbnail_path = '_static/blitz_1_introduction.png'

0 comments on commit 62e23bd

Please sign in to comment.