forked from pyg-team/pytorch_geometric
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add basic graphgym test * tensorboard * typo * update * add print * fix windows issues
- Loading branch information
Showing
6 changed files
with
124 additions
and
10 deletions.
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# GraphGym example configuration: a small GCN for node classification on the
# Cora citation dataset, trained for only 3 epochs (used by the graphgym test).
tensorboard_each_run: False   # no per-run TensorBoard logging
tensorboard_agg: False        # no aggregated TensorBoard logging
dataset:
  format: PyG
  name: Cora
  task: node                  # node-level prediction
  task_type: classification
  node_encoder: False         # raw node features are used directly
  node_encoder_name: Atom     # presumably ignored while node_encoder is False — verify
  edge_encoder: False
  edge_encoder_name: Bond     # presumably ignored while edge_encoder is False — verify
train:
  batch_size: 128
  eval_period: 1              # evaluate every epoch
  ckpt_period: 100            # > max_epoch, so effectively no checkpoints here
  sampler: full_batch
model:
  type: gnn
  loss_fun: cross_entropy
  edge_decoding: dot          # NOTE(review): likely only used for link tasks — confirm
  graph_pooling: add          # NOTE(review): likely only used for graph tasks — confirm
gnn:
  layers_pre_mp: 0            # no MLP layers before message passing
  layers_mp: 2                # two message-passing layers
  layers_post_mp: 1           # one head layer after message passing
  dim_inner: 16               # hidden dimension
  layer_type: gcnconv
  stage_type: stack
  batchnorm: False
  act: prelu
  dropout: 0.1
  agg: mean
  normalize_adj: False
optim:
  optimizer: adam
  base_lr: 0.01
  max_epoch: 3                # tiny run to keep the test fast
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import os.path as osp | ||
import torch | ||
import random | ||
import sys | ||
import shutil | ||
|
||
from collections import namedtuple | ||
|
||
from torch_geometric import seed_everything | ||
from torch_geometric.graphgym.train import train | ||
from torch_geometric.graphgym.loader import create_loader | ||
from torch_geometric.graphgym.model_builder import create_model | ||
from torch_geometric.graphgym.models.head import GNNNodeHead | ||
from torch_geometric.graphgym.logger import set_printing, create_logger | ||
from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNStackStage | ||
from torch_geometric.graphgym.config import (cfg, dump_cfg, set_run_dir, | ||
set_agg_dir, load_cfg) | ||
from torch_geometric.graphgym.utils import (agg_runs, params_count, | ||
auto_select_device) | ||
from torch_geometric.graphgym.optimizer import (create_optimizer, | ||
create_scheduler, | ||
OptimizerConfig, | ||
SchedulerConfig) | ||
|
||
|
||
def test_run_single_graphgym():
    """End-to-end smoke test of a single GraphGym run.

    Loads the bundled ``example_node.yml`` config, builds loaders, loggers,
    model, optimizer and scheduler, trains for a few epochs, aggregates the
    results, and finally removes everything that was written to disk.
    """
    import tempfile  # local import: only this test needs it

    # load_cfg() expects an argparse-like namespace with these two fields.
    Args = namedtuple('Args', ['cfg_file', 'opts'])
    root = osp.dirname(osp.realpath(__file__))
    args = Args(osp.join(root, 'example_node.yml'), [])

    load_cfg(cfg, args)
    # Use tempfile.mkdtemp() rather than hand-built '/tmp/<random>' paths so
    # the test also works on Windows, where '/tmp' does not exist.
    cfg.out_dir = tempfile.mkdtemp()
    cfg.run_dir = tempfile.mkdtemp()
    cfg.dataset.dir = tempfile.mkdtemp()
    dump_cfg(cfg)
    set_printing()

    seed_everything(cfg.seed)
    auto_select_device()
    set_run_dir(cfg.out_dir, args.cfg_file)

    # One loader and one logger per split (train/val/test).
    loaders = create_loader()
    assert len(loaders) == 3

    loggers = create_logger()
    assert len(loggers) == 3

    # The model built from the config is encoder -> message passing -> head.
    model = create_model()
    assert isinstance(model, torch.nn.Module)
    assert isinstance(model.encoder, FeatureEncoder)
    assert isinstance(model.mp, GNNStackStage)
    assert isinstance(model.post_mp, GNNNodeHead)

    optimizer_config = OptimizerConfig(optimizer=cfg.optim.optimizer,
                                       base_lr=cfg.optim.base_lr,
                                       weight_decay=cfg.optim.weight_decay,
                                       momentum=cfg.optim.momentum)
    optimizer = create_optimizer(model.parameters(), optimizer_config)
    assert isinstance(optimizer, torch.optim.Adam)

    scheduler_config = SchedulerConfig(scheduler=cfg.optim.scheduler,
                                       steps=cfg.optim.steps,
                                       lr_decay=cfg.optim.lr_decay,
                                       max_epoch=cfg.optim.max_epoch)
    scheduler = create_scheduler(optimizer, scheduler_config)
    assert isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingLR)

    # Exact parameter count of the model defined by example_node.yml.
    cfg.params = params_count(model)
    assert cfg.params == 23336

    train(loggers, loaders, model, optimizer, scheduler)

    agg_runs(set_agg_dir(cfg.out_dir, args.cfg_file), cfg.metric_best)

    # Clean up everything the run wrote to disk.
    shutil.rmtree(cfg.out_dir)
    shutil.rmtree(cfg.dataset.dir)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters