diff --git a/PyG_to_sci_sparse.py b/PyG_to_sci_sparse.py index 702f8de..b4e9275 100644 --- a/PyG_to_sci_sparse.py +++ b/PyG_to_sci_sparse.py @@ -179,24 +179,26 @@ def prep_graph(dataset, num_nodes = data.__num_nodes__ else: num_nodes = data.num_nodes - - if hasattr(dataset, 'get_idx_split'): - split = dataset.get_idx_split() - else: - split = dict( - train=data.train_mask.nonzero().squeeze(), - valid=data.val_mask.nonzero().squeeze(), - test=data.test_mask.nonzero().squeeze() - ) + + # if hasattr(dataset, 'get_idx_split'): + # split = dataset.get_idx_split() + # else: + # split = dict( + # train=data.train_mask.nonzero().squeeze(), + # valid=data.val_mask.nonzero().squeeze(), + # test=data.test_mask.nonzero().squeeze() + # ) # converting to numpy arrays, so we don't have to handle different # array types (tensor/numpy/list) later on. # Also we need numpy arrays because Numba cant determine type of torch.Tensor - split = {k: v.numpy() for k, v in split.items()} + # split = {k: v.numpy() for k, v in split.items()} edge_index = data.edge_index.cpu() + print('edge_index', edge_index) if data.edge_attr is None: - edge_weight = torch.ones(edge_index.size(1)) + # edge_weight = torch.ones(edge_index.size(1)) + edge_weight = torch.full((edge_index.size(1),), 2.0) else: edge_weight = data.edge_attr edge_weight = edge_weight.cpu() @@ -223,7 +225,7 @@ def prep_graph(dataset, adj = torch_sparse.SparseTensor.from_scipy(adj).coalesce().to(device) attr_matrix = data.x.cpu().numpy() - print(adj) + print('prep_graph adj', adj) attr = torch.from_numpy(attr_matrix).to(device) logging.debug("Memory Usage after normalizing graph attributes:") @@ -250,6 +252,7 @@ def to_symmetric_scipy(adjacency: sp.csr_matrix): from torch_geometric.nn import GCN from torch_geometric.datasets import Amazon from torch_geometric.datasets import QM9, CoraFull, Planetoid +from ogb.nodeproppred import PygNodePropPredDataset import inspect from gnn_toolbox.common import is_directed coraa = CoraFull(root='./datasets', transform=T.ToSparseTensor()) @@ -270,10 +273,10 @@ def to_symmetric_scipy(adjacency: sp.csr_matrix): # dataset2 = Planetoid(name = 'cora', root = './datasets') # dataset5 = QM9(root='./datasets', transform=T.Compose([T.ToUndirected(), T.ToSparseTensor()])) -# dataset6 = QM9(root='./datasets') +# dataset6 = PygNodePropPredDataset(name='ogbn-arxiv', root='./datasets') # data6= dataset6[0] -# print('qm9', data6.edge_attr) - +# print('qm9', data6) +# print('qm9.edge_attr', data6.edge_attr) # from ogb.nodeproppred import PygNodePropPredDataset # dataset3 = PygNodePropPredDataset(name='ogbn-arxiv', root='./datasets') # dataset4 = PygNodePropPredDataset(name='ogbn-arxiv', root='./datasets', transform=T.ToUndirected()) @@ -281,7 +284,11 @@ def to_symmetric_scipy(adjacency: sp.csr_matrix): # # print('pyg', data) # data3 = dataset3[0] # print('is_directed ogb', data3.is_directed()) -# attr, adj, labels, splits, n = prep_graph(dataset4, make_undirected=False) + +# attr, adj, labels, splits, n = prep_graph(dataset6, make_undirected=False) +# row, col, edge_attr2 = adj.coo() +# edge_index = torch.stack([row, col], dim=0) +# print('after edge_attr', edge_attr2) # print('pyg prep_graph adj', adj) # attr, adj, labels, splits, n = prep_graph(dataset3, make_undirected=True) # print('robu prep_graph adj', adj) @@ -295,20 +302,119 @@ def to_symmetric_scipy(adjacency: sp.csr_matrix): # # model = DICE() # amazon = Amazon(root='./datasets', name='computers') # data9 = amazon[0] +def accuracy(pred, y, mask): + return (pred.argmax(-1)[mask] == y[mask]).float().mean() # print(data9.is_directed()) -# dataset1= Planetoid(name = 'cora', root = './datasets', transform=T.ToUndirected()) -# data = dataset1[0] -# print('data.edge_index', data.edge_index) +dataset2= PygNodePropPredDataset(name='ogbn-arxiv', root='./datasets', transform=T.ToSparseTensor(remove_edge_index=False)) +dataset1= Planetoid(name = 'cora', root = './datasets') +data = dataset1[0] +train=data.train_mask.nonzero().squeeze() +model1= GCN(in_channels=dataset1.num_features, out_channels=dataset1.num_classes, hidden_channels=32, num_layers=2) +logits = model1(data.x, data.edge_index) +# print(accuracy(logits, data.y, train)) +# print(accuracy(logits, data.y, data.train_mask)) +# from torch_geometric.datasets import PPI +# dataset1 = PPI(root='./datasets', transform=T.ToSparseTensor(remove_edge_index=False)) +# print('pygnodepred', dataset1) +# we = dataset1[0] +# print('we', we) + +# dataset2= Planetoid(name = 'citeseer', root = './datasets') +# dataset2= QM9(root = './datasets') +data2 = dataset2[0] +# train_mask=data2.train_mask.nonzero().squeeze() +# attr, adj, labels, splits, n = prep_graph(dataset2, make_undirected=False) +# print() +from torch_geometric.utils import to_undirected, is_undirected +print('is_undirected', is_undirected(data2.edge_index, num_nodes=data2.num_nodes)) +adj = SparseTensor(row=data2.edge_index[0], col=data2.edge_index[1], value=data2.edge_attr).to('cuda') +adj2 = SparseTensor(row=data2.edge_index[0], col=data2.edge_index[1], value=data2.edge_weight if data2.edge_weight is not None else torch.ones(data2.edge_index.size(1))).to('cuda') +print('num_edge',adj.nnz()) +print('what', adj) +print('what2', adj.t()) + +row, col, edge_attr = adj.t().coo() +edge_index = torch.stack([row, col], dim=0) +print('edge_index3', edge_index) + +# row, col, edge_attr = adj.coo() +# edge_index = torch.stack([row, col], dim=0) +# print('edge_index1', edge_index) -# print('data2.edge_index', data2.adj_t) -# row, col, edge_attr = adj_t.t().coo() -# edge_index = torch.stack([row, col], dim=0) +edge_index = to_undirected(data2.edge_index) +print('x',edge_index.shape[1]) +print('undirected', edge_index) +print(torch.equal(data2.edge_index, edge_index)) +print("numm", data2.num_nodes) + +print("numm2", data2.x.shape[0]) +print("numm2", data2.x.size(0)) +# row, col, edge_attr = adj.coo() +# cora_edge_index = torch.stack([row, col], dim=0) +# print('edge_attr', edge_attr) +# data = dataset1[0] +# row, col, edge_attr = data.adj_t.coo() +# cora_edge_index = torch.stack([row, col], dim=0) + +# # cora_edge_index = cora_edge_index.to('cpu') +# if edge_attr is None: +# edge_attr = torch.ones(cora_edge_index.size(1)) # row2, col2, edge_attr2 = adj_t.coo() # edge_index2 = torch.stack([row2, col2], dim=0) # print('data2 edge_index', edge_index2) # print('are they equal:', torch.equal(edge_index, edge_index2)) -# model1= GCN(in_channels=dataset.num_features, out_channels=dataset.num_classes, hidden_channels=16, num_layers=2, dropout=0.5) + +import torch.nn as nn +from torch_geometric.nn import GCNConv + +def accuracy(pred, y, mask): + return (pred.argmax(-1)[mask] == y[mask]).float().mean() + +# @register_model("GCN") +class GCNWH(nn.Module): + def __init__(self, in_channels: int, out_channels: int, hidden_channels: int, num_layers: int = 2, dropout: float = 0.5, **kwargs): + super().__init__() + self.GCNConv1 = GCNConv(in_channels=in_channels, out_channels=hidden_channels) + self.layers = nn.ModuleList([GCNConv(hidden_channels, hidden_channels) for _ in range(num_layers)]) + self.dropout = nn.Dropout(dropout) + self.GCNConv2 = GCNConv(in_channels=hidden_channels, out_channels=out_channels) + + def forward(self, x, edge_index, edge_weight, **kwargs): + x = self.GCNConv1(x, edge_index, edge_weight).relu() + x = F.dropout(x, training=self.training) + x = self.GCNConv2(x, edge_index, edge_weight) + return x + +def train3(model, attr, edge_index, data, label,edge_weight=None, epochs=200, lr=0.01, weight_decay=5e-4): + model = model.to('cuda') + res = [] + model.train() + optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) + for module in model.children(): + module.reset_parameters() + for _ in range(epochs): + optimizer.zero_grad() + pred = model(attr, edge_index, edge_weight) + loss = F.cross_entropy(pred[data.train_mask], label[data.train_mask]) + acc = float(accuracy(pred, label, data.train_mask)) + res.append(acc) + loss.backward() + optimizer.step() + + with torch.no_grad(): + model.eval() + pred = model(attr, edge_index, edge_weight) + acc = float(accuracy(pred, label, data.train_mask)) + print('model2 acc', acc) + return res + +# print(train3(model1, data.x, cora_edge_index, edge_attr, data, data.y)) +# from pprint import pprint +# res =train3(model1,attr, cora_edge_index, data2, labels, edge_attr) +# pprint(res) + +# train2(model1,) # model2 = GCN(in_channels=dataset.num_features, out_channels=dataset.num_classes, hidden_channels=16, num_layers=2, dropout=0.5) def train(model, data, epochs=200, lr=0.01, weight_decay=5e-4): @@ -321,8 +427,7 @@ def train(model, data, epochs=200, lr=0.01, weight_decay=5e-4): loss.backward() optimizer.step() -def accuracy(pred, y, mask): - return (pred.argmax(-1)[mask] == y[mask]).float().mean() + @torch.no_grad() def test(model, data): @@ -340,7 +445,7 @@ def train2(model, attr, adj, split, label, epochs=200, lr=0.01, weight_decay=5e- optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) for _ in range(epochs): optimizer.zero_grad() - pred = model(attr, adj.t()) + pred = model(attr, adj) loss = F.cross_entropy(pred[split['train']], label[split['train']]) loss.backward() optimizer.step() diff --git a/configs/default_experiment.yaml b/configs/default_experiment.yaml index 4d7d4bf..46de0db 100644 --- a/configs/default_experiment.yaml +++ b/configs/default_experiment.yaml @@ -1,6 +1,7 @@ -output_dir: ./output2 +output_dir: ./output6 # resume_output: False # csv_save: True +# esd: True experiment_templates: - name: My experiment seed: [0] @@ -8,13 +9,13 @@ experiment_templates: model: - name: GCN params: - hidden_channels: 32 + hidden_channels: 64 dataset: - name: Cora + name: Cora #ogbn-arxiv root: ./datasets - make_undirected: True - transforms: - - name: NormalizeFeatures + make_undirected: False + # transforms: + # - name: NormalizeFeatures # - name: Constant # params: # value: 0.8 @@ -36,20 +37,21 @@ experiment_templates: # name: DICE, FGSM, PGD, PRBCD, GreedyRBCD] # type: [poison, evasion] # epsilon: [0.5] - - scope: global - name: [PRBCD, GreedyRBCD] + - scope: local + name: [LocalPRBCD, LocalDICE] type: poison - epsilon: [0.5] + epsilon: [0.3] + nodes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] # - scope: local - # name: localDICE + # name: Nettack # type: evasion # epsilon: [0.2] - # nodes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + # nodes: [0] training: - max_epochs: 300 - patience: 100 + max_epochs: 100 + patience: 70 optimizer: name: adam diff --git a/gnn_toolbox/common.py b/gnn_toolbox/common.py index df11924..f80abf6 100644 --- a/gnn_toolbox/common.py +++ b/gnn_toolbox/common.py @@ -10,18 +10,30 @@ """ import logging +import inspect import numpy as np import torch import torch.nn.functional as F import scipy.sparse as sp from tqdm.auto import tqdm -from typing import Tuple, Union, Optional, List, Dict, Any +from typing import Tuple, Union, Optional, List, Dict, Any from torchtyping import TensorType from torch_sparse import SparseTensor - - - -def train(model, attr, adj, labels, idx_train, idx_val, idx_test, optimizer, loss, max_epochs, patience): +from torch_geometric.utils import to_undirected + +def train( + model, + attr, + adj, + labels, + idx_train, + idx_val, + idx_test, + optimizer, + loss, + max_epochs, + patience, +): """Train a model using either standard training. Parameters ---------- @@ -59,16 +71,30 @@ def train(model, attr, adj, labels, idx_train, idx_val, idx_test, optimizer, los # trace_acc_test = [] results = [] # optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) + # from torch_geometric.datasets import Planetoid + # from torch_geometric.transforms import ToSparseTensor + # cora = Planetoid( + # root="/datasets", name="Cora", transform=ToSparseTensor(remove_edge_index=False) + # ) + # qw = cora[0] best_loss = np.inf - + edge_index_rows, edge_index_cols, edge_weight = adj.coo() + edge_index = torch.stack([edge_index_rows, edge_index_cols], dim=0) + # edge_weight2 = edge_weight2.float() + # attr = attr.float() + # .to(self.device) model.train() - for epoch in tqdm(range(max_epochs), desc='Training...'): + for epoch in tqdm(range(max_epochs), desc="Training"): optimizer.zero_grad() - logits = model(attr, adj) - - + logits = get_logits(model, attr, edge_index, edge_weight) + # logits = model( + # qw.x.to("cuda"), + # qw.edge_index.to("cuda"), + # torch.ones_like(qw.edge_index[0], dtype=torch.float32).to("cuda"), + # ) + loss_train = loss(logits[idx_train], labels[idx_train]) loss_val = loss(logits[idx_val], labels[idx_val]) @@ -85,7 +111,7 @@ def train(model, attr, adj, labels, idx_train, idx_val, idx_test, optimizer, los # trace_acc_train.append(train_acc) # trace_acc_val.append(val_acc) # trace_acc_test.append(test_acc) - + if loss_val < best_loss: best_loss = loss_val best_epoch = epoch @@ -107,47 +133,56 @@ def train(model, attr, adj, labels, idx_train, idx_val, idx_test, optimizer, los # restore the best validation state model.load_state_dict(best_state) # save the model - + return results +def get_logits(model, attr, edge_index, edge_weight, idx=None): + sig = inspect.signature(model.forward) + # print('sig.parameters', sig.parameters) + if "edge_weight" in sig.parameters or "edge_attr" in sig.parameters: + logits = model(attr, edge_index, edge_weight) + else: + logits = model(attr, edge_index) -def classification_statistics(logits: TensorType[1, "n_classes"], - label: TensorType[()]) -> Dict[str, float]: - logits, label = F.log_softmax(logits.cpu(), dim=-1), label.cpu() - logits = logits[0] - logit_target = logits[label].item() - sorted = logits.argsort() - logit_best_non_target = (logits[sorted[sorted != label][-1]]).item() - confidence_target = np.exp(logit_target) - confidence_non_target = np.exp(logit_best_non_target) - margin = confidence_target - confidence_non_target - return { - 'logit_target': logit_target, - 'logit_best_non_target': logit_best_non_target, - 'confidence_target': confidence_target, - 'confidence_non_target': confidence_non_target, - 'margin': margin - } - -def gen_local_attack_nodes(attr, adj, labels, model, idx_test, device, topk=10, min_node_degree=2): + if idx is not None: + logits = logits[idx] + return logits + + +def gen_local_attack_nodes( + attr, adj, labels, model, idx_test, device, topk=10, min_node_degree=2 +): logits, acc = evaluate_model(model, attr, adj, labels, idx_test, device) - + logging.info(f"Sample Attack Nodes for model with accuracy {acc:.4}") - max_confidence_nodes_idx, min_confidence_nodes_idx, rand_nodes_idx = sample_attack_nodes( - logits, labels[idx_test], idx_test, adj, topk, min_node_degree) - tmp_nodes = np.concatenate((max_confidence_nodes_idx, min_confidence_nodes_idx, rand_nodes_idx)) + max_confidence_nodes_idx, min_confidence_nodes_idx, rand_nodes_idx = ( + sample_attack_nodes( + logits, labels[idx_test], idx_test, adj, topk, min_node_degree + ) + ) + tmp_nodes = np.concatenate( + (max_confidence_nodes_idx, min_confidence_nodes_idx, rand_nodes_idx) + ) logging.info( - f"Sample the following attack nodes:\n{max_confidence_nodes_idx}\n{min_confidence_nodes_idx}\n{rand_nodes_idx}") + f"Sample the following attack nodes:\n{max_confidence_nodes_idx}\n{min_confidence_nodes_idx}\n{rand_nodes_idx}" + ) return tmp_nodes -def sample_attack_nodes(logits: torch.Tensor, labels: torch.Tensor, nodes_idx, - adj: SparseTensor, topk: int, min_node_degree: int): + +def sample_attack_nodes( + logits: torch.Tensor, + labels: torch.Tensor, + nodes_idx, + adj: SparseTensor, + topk: int, + min_node_degree: int, +): assert logits.shape[0] == labels.shape[0] if isinstance(nodes_idx, torch.Tensor): nodes_idx = nodes_idx.cpu() node_degrees = adj[nodes_idx.tolist()].sum(-1) - print('len(node_degrees)', len(node_degrees)) + print("len(node_degrees)", len(node_degrees)) suitable_nodes_mask = (node_degrees >= min_node_degree).cpu() labels = labels.cpu()[suitable_nodes_mask] @@ -158,48 +193,87 @@ def sample_attack_nodes(logits: torch.Tensor, labels: torch.Tensor, nodes_idx, logging.info( f"Found {sum(suitable_nodes_mask)} suitable '{min_node_degree}+ degree' nodes out of {len(nodes_idx)} " f"candidate nodes to be sampled from for the attack of which {correctly_classifed.sum().item()} have the " - "correct class label") + "correct class label" + ) print( f"Found {sum(suitable_nodes_mask)} suitable '{min_node_degree}+ degree' nodes out of {len(nodes_idx)} " f"candidate nodes to be sampled from for the attack of which {correctly_classifed.sum().item()} have the " - "correct class label") + "correct class label" + ) print(sum(suitable_nodes_mask)) - assert sum(suitable_nodes_mask) >= (topk * 4), \ - f"There are not enough suitable nodes to sample {(topk*4)} nodes from" + assert sum(suitable_nodes_mask) >= ( + topk * 4 + ), f"There are not enough suitable nodes to sample {(topk*4)} nodes from" - _, max_confidence_nodes_idx = torch.topk(confidences[correctly_classifed].max(-1).values, k=topk) - _, min_confidence_nodes_idx = torch.topk(-confidences[correctly_classifed].max(-1).values, k=topk) + _, max_confidence_nodes_idx = torch.topk( + confidences[correctly_classifed].max(-1).values, k=topk + ) + _, min_confidence_nodes_idx = torch.topk( + -confidences[correctly_classifed].max(-1).values, k=topk + ) rand_nodes_idx = np.arange(correctly_classifed.sum().item()) rand_nodes_idx = np.setdiff1d(rand_nodes_idx, max_confidence_nodes_idx) rand_nodes_idx = np.setdiff1d(rand_nodes_idx, min_confidence_nodes_idx) rnd_sample_size = min((topk * 2), len(rand_nodes_idx)) - rand_nodes_idx = np.random.choice(rand_nodes_idx, size=rnd_sample_size, replace=False) - - return (np.array(nodes_idx[suitable_nodes_mask][correctly_classifed][max_confidence_nodes_idx])[None].flatten(), - np.array(nodes_idx[suitable_nodes_mask][correctly_classifed][min_confidence_nodes_idx])[None].flatten(), - np.array(nodes_idx[suitable_nodes_mask][correctly_classifed][rand_nodes_idx])[None].flatten()) - + rand_nodes_idx = np.random.choice( + rand_nodes_idx, size=rnd_sample_size, replace=False + ) + + return ( + np.array( + nodes_idx[suitable_nodes_mask][correctly_classifed][ + max_confidence_nodes_idx + ] + )[None].flatten(), + np.array( + nodes_idx[suitable_nodes_mask][correctly_classifed][ + min_confidence_nodes_idx + ] + )[None].flatten(), + np.array(nodes_idx[suitable_nodes_mask][correctly_classifed][rand_nodes_idx])[ + None + ].flatten(), + ) + +def to_edge_index(adj, device): + if isinstance(adj, SparseTensor): + edge_index_rows, edge_index_cols, edge_weight = adj.coo() + edge_index = torch.stack([edge_index_rows, edge_index_cols], dim=0).to(device) + return edge_index, edge_weight.to(device) @torch.no_grad() -def evaluate_model(model, - attr: TensorType["n_nodes", "n_features"], - adj: Union[SparseTensor, TensorType["n_nodes", "n_nodes"]], - labels: TensorType["n_nodes"], - idx_test: Union[List[int], np.ndarray], - device): +def evaluate_model( + model, + attr: TensorType["n_nodes", "n_features"], + adj: Union[SparseTensor, TensorType["n_nodes", "n_nodes"]], + labels: TensorType["n_nodes"], + idx_test: Union[List[int], np.ndarray], + device: str, +): """ Evaluates any model w.r.t. accuracy for a given (perturbed) adjacency and attribute matrix. """ model.eval() - pred_logits_target = model(attr, adj)[idx_test] - - acc_test_target = accuracy(pred_logits_target.cpu(), labels.cpu()[idx_test], np.arange(pred_logits_target.shape[0])) + edge_index, edge_weight = to_edge_index(adj, device) + + + # pred_logits_target = model(attr, adj, edge_weight)[idx_test] + pred_logits_target = get_logits(model, attr, edge_index, edge_weight, idx_test) + + acc_test_target = accuracy( + pred_logits_target.cpu(), + labels.cpu()[idx_test], + np.arange(pred_logits_target.shape[0]), + ) return pred_logits_target, acc_test_target -def accuracy(logits: torch.Tensor, labels: torch.Tensor, split_idx: np.ndarray) -> float: + +def accuracy( + logits: torch.Tensor, labels: torch.Tensor, split_idx: np.ndarray +) -> float: """Returns the accuracy for a tensor of logits, a list of lables and and a split indices. Parameters @@ -218,6 +292,7 @@ def accuracy(logits: torch.Tensor, labels: torch.Tensor, split_idx: np.ndarray) """ return (logits.argmax(1)[split_idx] == labels[split_idx]).float().mean().item() + def random_splitter(labels, n_per_class=20, seed=None): """ Randomly split the training data. @@ -248,25 +323,31 @@ def random_splitter(labels, n_per_class=20, seed=None): for label in range(nc): perm = np.random.permutation((labels == label).nonzero()[0]) split_train.append(perm[:n_per_class]) - split_val.append(perm[n_per_class:2 * n_per_class]) + split_val.append(perm[n_per_class : 2 * n_per_class]) split_train = np.random.permutation(np.concatenate(split_train)) split_val = np.random.permutation(np.concatenate(split_val)) assert split_train.shape[0] == split_val.shape[0] == n_per_class * nc - split_test = np.setdiff1d(np.arange(len(labels)), np.concatenate((split_train, split_val))) + split_test = np.setdiff1d( + np.arange(len(labels)), np.concatenate((split_train, split_val)) + ) return dict(split_train, split_val, split_test) -def prepare_dataset(dataset, - experiment: Dict[str, Any], - graph_index: int, - make_undirected: bool, - ) -> Tuple[TensorType["num_nodes", "num_features"], - SparseTensor, - TensorType["num_nodes"], - Optional[Dict[str, np.ndarray]]]: + +def prepare_dataset( + dataset, + experiment: Dict[str, Any], + graph_index: int, + make_undirected: bool, +) -> Tuple[ + TensorType["num_nodes", "num_features"], + SparseTensor, + TensorType["num_nodes"], + Optional[Dict[str, np.ndarray]], +]: """Prepares and normalizes the desired dataset Parameters @@ -287,66 +368,89 @@ def prepare_dataset(dataset, Tuple[torch.Tensor, torch_sparse.SparseTensor, torch.Tensor] dense attribute tensor, sparse adjacency matrix (normalized) and labels tensor. """ - + logging.debug("Memory Usage before loading the dataset:") logging.debug(torch.cuda.memory_allocated(device=None) / (1024**3)) if graph_index is None: graph_index = 0 - + data = dataset[graph_index] - if hasattr(data, '__num_nodes__'): - num_nodes = data.__num_nodes__ - else: + if hasattr(data, "num_nodes"): num_nodes = data.num_nodes + else: + num_nodes = data.x.shape[0] # converting to numpy arrays, so we don't have to handle different array types (tensor/numpy/list) later on. # Also we need numpy arrays because Numba cant determine type of torch.Tensor - - device = experiment['device'] - edge_index = data.edge_index - if data.edge_attr: + + device = experiment["device"] + # edge_index = data.edge_index + if data.edge_attr is not None: edge_weight = data.edge_attr - elif data.edge_weight: + elif data.edge_weight is not None: edge_weight = data.edge_weight else: - edge_weight = torch.ones(edge_index.size(1)) - edge_weight = edge_weight.cpu() - - adj = sp.csr_matrix((edge_weight, edge_index), (num_nodes, num_nodes)) - - del edge_index - del edge_weight - - # make unweighted - adj.data = np.ones_like(adj.data) - + edge_weight = torch.ones(data.edge_index.shape[1]) + + num_edges = data.edge_index.size(1) + edge_index = data.edge_index if make_undirected: - adj = to_symmetric_scipy(adj) - num_edges = adj.nnz / 2 + edge_index, edge_weight = to_undirected(edge_index, edge_weight, num_nodes, reduce="mean") + num_edges = edge_index.shape[1] logging.debug("Memory Usage after making the graph undirected:") logging.debug(torch.cuda.memory_allocated(device=None) / (1024**3)) - else: - num_edges = adj.nnz - adj = SparseTensor.from_scipy(adj).coalesce().to(device) + adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_weight).t().to(device) + + # edge_weight = edge_weight.cpu() + + # adj = sp.csr_matrix((edge_weight, edge_index), (num_nodes, num_nodes)) + + # del edge_index + # # del edge_weight + + # # make unweighted + # adj.data = np.ones_like(adj.data) + + # if make_undirected: + # adj = to_symmetric_scipy(adj) + # num_edges = adj.nnz / 2 + # logging.debug("Memory Usage after making the graph undirected:") + # logging.debug(torch.cuda.memory_allocated(device=None) / (1024**3)) + # else: + # num_edges = adj.nnz + + # adj = SparseTensor.from_scipy(adj).coalesce().to(device) + # if edge_weight.dtype != torch.float32: + # edge_weight = edge_weight.float() + # if make_undirected: + # num_edges = adj.nnz() / 2 + # else: + # num_edges = adj.nnz() attr_matrix = data.x.cpu().numpy() attr = torch.from_numpy(attr_matrix).to(device) - + # edge_weight = edge_weight.cpu().numpy() + # edge_weight = torch.from_numpy(edge_weight).to(device) + labels = data.y.squeeze().to(device) - split = splitter(dataset, data, labels, experiment['seed']) + split = splitter(dataset, data, labels, experiment["seed"]) split = {k: v.numpy() for k, v in split.items()} - - experiment['model']['params'].update({ - 'in_channels': attr.size(1), - 'out_channels': int(labels[~labels.isnan()].max() + 1) - }) - + + experiment["model"]["params"].update( + { + "in_channels": attr.shape[1], + "out_channels": int(labels[~labels.isnan()].max() + 1), + } + ) + return attr, adj, labels, split, num_edges +# , edge_weight + def splitter(dataset, data, labels, seed): """ @@ -361,22 +465,26 @@ def splitter(dataset, data, labels, seed): Returns: A dictionary containing the indices of the train, validation, and test sets. """ - if hasattr(dataset, 'get_idx_split'): + if hasattr(dataset, "get_idx_split"): split = dataset.get_idx_split() - logging.info(f"Using the provided split from get_idx_split().") + logging.debug(f"Using the provided split from get_idx_split().") + return split else: try: split = dict( train=data.train_mask.nonzero().squeeze(), valid=data.val_mask.nonzero().squeeze(), - test=data.test_mask.nonzero().squeeze() + test=data.test_mask.nonzero().squeeze(), ) - logging.info(f"Using the provided split with train, val, test mask.") + logging.debug(f"Using the provided split with train, val, test mask.") return split except AttributeError: - logging.info(f"Dataset doesn't provide train, val, test splits. Using random_splitter() for the splitting.") + logging.debug( + f"Dataset doesn't provide train, val, test splits. Using random_splitter() for the splitting." + ) return random_splitter(labels=labels.cpu().numpy(), seed=seed) + def to_symmetric_scipy(adjacency: sp.csr_matrix): sym_adjacency = (adjacency + adjacency.T).astype(bool).astype(float) @@ -384,7 +492,7 @@ def to_symmetric_scipy(adjacency: sp.csr_matrix): return sym_adjacency + def is_directed(adj_matrix) -> bool: - """Check if the graph is directed (adjacency matrix is not symmetric). - """ - return (adj_matrix != adj_matrix.t()).sum() != 0 \ No newline at end of file + """Check if the graph is directed (adjacency matrix is not symmetric).""" + return (adj_matrix != adj_matrix.t()).sum() != 0 diff --git a/gnn_toolbox/custom_modules/__init__.py b/gnn_toolbox/custom_components/__init__.py similarity index 100% rename from gnn_toolbox/custom_modules/__init__.py rename to gnn_toolbox/custom_components/__init__.py diff --git a/gnn_toolbox/custom_modules/attacks/__init__.py b/gnn_toolbox/custom_components/attacks/__init__.py similarity index 100% rename from gnn_toolbox/custom_modules/attacks/__init__.py rename to gnn_toolbox/custom_components/attacks/__init__.py diff --git a/gnn_toolbox/custom_modules/attacks/base_attack.py b/gnn_toolbox/custom_components/attacks/base_attack.py similarity index 60% rename from gnn_toolbox/custom_modules/attacks/base_attack.py rename to gnn_toolbox/custom_components/attacks/base_attack.py index 47348c9..0b1322c 100644 --- a/gnn_toolbox/custom_modules/attacks/base_attack.py +++ b/gnn_toolbox/custom_components/attacks/base_attack.py @@ -16,6 +16,7 @@ from torchtyping import TensorType, patch_typeguard from typeguard import typechecked import logging +import inspect import numpy as np import scipy.sparse as sp @@ -26,8 +27,6 @@ from gnn_toolbox.registry import registry, get_from_registry -# patch_typeguard() - @typechecked class BaseAttack(ABC): """ @@ -38,7 +37,7 @@ class BaseAttack(ABC): adj : SparseTensor or torch.Tensor [n, n] (sparse) adjacency matrix. attr : torch.Tensor - [n, d] feature/attribute matrix. + [n, d] feature/attribute matrix. labels : torch.Tensor Labels vector of shape [n]. idx_attack : np.ndarray @@ -67,73 +66,76 @@ def __init__( loss_type: str = "CE", **kwargs, ): - self.device = device - self.idx_attack = idx_attack + self.attr = attr.to(self.device) # unperturbed attributes + self.adj = adj.to(self.device) # unperturbed adjacency + + self.attr_adversary = self.attr # perturbed attributes + self.adj_adversary = self.adj # adjacency matrix that can be perturbed + + self.idx_attack = idx_attack + self.labels = labels.to(torch.long).to(self.device) + self.labels_attack = self.labels[self.idx_attack] + self.loss_type = loss_type - self.make_undirected = make_undirected self.attacked_model = deepcopy(model).to(self.device) self.attacked_model.eval() for params in self.attacked_model.parameters(): params.requires_grad = False - self.eval_model = self.attacked_model - self.labels = labels.to(torch.long).to(self.device) - self.labels_attack = self.labels[self.idx_attack] - self.attr = attr.to(self.device) - self.adj = adj.to(self.device) + + self.eval_model = self.attacked_model - self.attr_adversary = self.attr - self.adj_adversary = self.adj @abstractmethod - def _attack(self, n_perturbations: int, **kwargs): - pass - def attack(self, n_perturbations: int, **kwargs): - """ - Executes the attack on the model updating the attributes - self.adj_adversary and self.attr_adversary accordingly. + pass - Parameters - ---------- - n_perturbations : int - number of perturbations (attack budget in terms of node additions/deletions) that constrain the atack - """ - if n_perturbations > 0: - return self._attack(n_perturbations, **kwargs) - else: - self.attr_adversary = self.attr - self.adj_adversary = self.adj - - def set_pertubations( - self, - adj_perturbed: Union[SparseTensor, TensorType], - attr_perturbed: TensorType, - ): - self.adj_adversary = adj_perturbed.to(self.device) - self.attr_adversary = attr_perturbed.to(self.device) + # def attack(self, n_perturbations: int, **kwargs): + # """ + # Executes the attack on the model updating the attributes + # self.adj_adversary and self.attr_adversary accordingly. + + # Parameters + # ---------- + # n_perturbations : int + # number of perturbations (attack budget in terms of node additions/deletions) that constrain the atack + # """ + # if n_perturbations > 0: + # return self._attack(n_perturbations, **kwargs) + # else: + # self.attr_adversary = self.attr + # self.adj_adversary = self.adj + def get_perturbations(self): adj_adversary, attr_adversary = self.adj_adversary, self.attr_adversary - if isinstance(self.adj_adversary, torch.Tensor): - # * might need to do to_dense() here since torch_geometric uses torch tensor - adj_adversary = SparseTensor.to_torch_sparse_coo_tensor(self.adj_adversary) + # if isinstance(self.adj_adversary, torch.Tensor): + # adj_adversary = SparseTensor.to_torch_sparse_coo_tensor(self.adj_adversary) if isinstance(self.attr_adversary, SparseTensor): attr_adversary = self.attr_adversary.to_dense() - return adj_adversary, attr_adversary def calculate_loss(self, logits, labels): - loss = get_from_registry("losses", self.loss_type, registry) + loss = get_from_registry("loss", self.loss_type, registry) return loss(logits, labels) - - + + def from_sparsetensor_to_edge_index(self, adj): + if isinstance(adj, SparseTensor): + edge_index_rows, edge_index_cols, edge_weight = adj.coo() + edge_index = torch.stack([edge_index_rows, edge_index_cols], dim=0).to(self.device) + return edge_index, edge_weight.to(self.device) + return None, None + + def from_edge_index_to_sparsetensor(self, edge_index, edge_weight): + return SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_weight).to(self.device) + + @typechecked -class SparseAttack(BaseAttack): +class GlobalAttack(BaseAttack): """ Base class for all sparse attacks. Just like the base attack class but automatically casting the adjacency to sparse format. @@ -141,29 +143,19 @@ class SparseAttack(BaseAttack): def __init__( self, - adj: Union[SparseTensor, TensorType, sp.csr_matrix], + adj: SparseTensor, make_undirected: bool = True, **kwargs, ): - if isinstance(adj, torch.Tensor): - adj = SparseTensor.from_dense(adj) - elif isinstance(adj, sp.csr_matrix): - adj = SparseTensor.from_scipy(adj) - super().__init__(adj, make_undirected=make_undirected, **kwargs) - edge_index_rows, edge_index_cols, edge_weight = adj.coo() - self.edge_index = torch.stack([edge_index_rows, edge_index_cols], dim=0).to( - self.device - ) - self.edge_weight = edge_weight.to(self.device) self.n = adj.size(0) self.d = self.attr.shape[1] - + self.num_nodes = self.attr.shape[0] @typechecked -class SparseLocalAttack(SparseAttack): +class LocalAttack(GlobalAttack): """ Base class for all local sparse attacks """ @@ -180,18 +172,23 @@ def get_logits( self, model, node_idx: int, - perturbed_graph: SparseTensor = None, + perturbed_graph: Union[SparseTensor, None] = None, ): if perturbed_graph is None: perturbed_graph = self.adj - return model(self.attr.to(self.device), perturbed_graph.to(self.device))[ - node_idx : node_idx + 1 - ] + + sig = inspect.signature(model.forward) + if "edge_weight" in sig.parameters or "edge_attr" in sig.parameters: + edge_index, edge_weight = self.from_sparsetensor_to_edge_index(perturbed_graph) + if edge_index is not None and edge_weight is not None: + return model(self.attr, edge_index, edge_weight)[node_idx : node_idx + 1] + raise ValueError("Model requires edge_weight or edge_attr but none provided") + return model(self.attr.to(self.device), perturbed_graph.to(self.device))[node_idx : node_idx + 1] def get_surrogate_logits( self, node_idx: int, - perturbed_graph: SparseTensor = None, + perturbed_graph: Union[SparseTensor, None] = None, ) -> torch.Tensor: return self.get_logits(self.attacked_model, node_idx, perturbed_graph) @@ -199,16 +196,16 @@ def get_eval_logits( self, node_idx: int, perturbed_graph: Optional[ - Union[SparseTensor, Tuple[TensorType[2, "nnz"], TensorType["nnz"]]] + Union[SparseTensor, Tuple[TensorType, TensorType], None] ] = None, ) -> torch.Tensor: return self.get_logits(self.eval_model, node_idx, perturbed_graph) @torch.no_grad() - def evaluate_local(self, node_idx: int): + def evaluate_node(self, node_idx: int): self.attacked_model.eval() - if torch.cuda.is_available(): + if self.device == "cuda": torch.cuda.empty_cache() torch.cuda.synchronize() memory = torch.cuda.memory_allocated() / (1024**3) @@ -218,7 +215,7 @@ def evaluate_local(self, node_idx: int): initial_logits = self.get_eval_logits(node_idx) - if torch.cuda.is_available(): + if self.device == "cuda": torch.cuda.empty_cache() torch.cuda.synchronize() memory = torch.cuda.memory_allocated() / (1024**3) @@ -229,6 +226,26 @@ def evaluate_local(self, node_idx: int): logits = self.get_eval_logits(node_idx, self.adj_adversary) return logits, initial_logits + @staticmethod + def classification_statistics( + logits: TensorType, label: TensorType + ) -> Dict[str, float]: + logits, label = F.log_softmax(logits.cpu(), dim=-1), label.cpu() + logits = logits[0] + logit_target = logits[label].item() + sorted = logits.argsort() + logit_best_non_target = (logits[sorted[sorted != label][-1]]).item() + confidence_target = np.exp(logit_target) + confidence_non_target = np.exp(logit_best_non_target) + margin = confidence_target - confidence_non_target + return { + "logit_target": logit_target, + "logit_best_non_target": logit_best_non_target, + "confidence_target": confidence_target, + "confidence_non_target": confidence_non_target, + "margin": margin, + } + def set_eval_model(self, model): self.eval_model = model.to(self.device) @@ -292,40 +309,50 @@ def _probability_margin_loss( :rtype: (Tensor) """ prob = F.softmax(prediction, dim=-1) - margin_ = SparseLocalAttack._margin_loss(prob, labels, idx_mask) + margin_ = LocalAttack._margin_loss(prob, labels, idx_mask) return margin_.mean() def adj_adversary_for_poisoning(self): return self.adj_adversary - -class DenseAttack(BaseAttack): - @typechecked - def __init__( - self, - adj: Union[SparseTensor, TensorType], - attr: TensorType, - labels: TensorType, - idx_attack: np.ndarray, - model, - device: Union[str, int, torch.device], - make_undirected: bool = True, - loss_type: str = "CE", - **kwargs, - ): - if isinstance(adj, SparseTensor): - adj = adj.to_dense() - - super().__init__( - adj, - attr, - labels, - idx_attack, - model, - device, - loss_type=loss_type, - make_undirected=make_undirected, - **kwargs, - ) - - self.n = adj.shape[0] +# from torch_geometric.datasets import Planetoid +# from torch_geometric.transforms import ToSparseTensor +# from torch_geometric.utils import dense_to_sparse +# class DenseAttack(BaseAttack): +# @typechecked +# def __init__( +# self, +# adj: Union[SparseTensor, TensorType], +# attr: TensorType, +# labels: TensorType, +# idx_attack: np.ndarray, +# model, +# device: Union[str, int, torch.device], +# make_undirected: bool = True, +# loss_type: str = "CE", +# **kwargs, +# ): +# if isinstance(adj, SparseTensor): +# adj = adj.to_dense() +# # adj = dense_to_sparse(adj) +# # cora = Planetoid(root='datasets', name='Cora',transform=ToSparseTensor(remove_edge_index=False)) +# # data = cora[0] +# # row, col, edge_attr = data.adj_t.t().coo() +# # edge_index = torch.stack([row, col], dim=0) +# # adj = edge_index +# # adj = data.adj_t.to_dense() +# # ad + +# super().__init__( +# adj, +# attr, +# labels, +# idx_attack, +# model, +# device, +# loss_type=loss_type, +# make_undirected=make_undirected, +# **kwargs, +# ) + +# self.n = adj.shape[0] diff --git a/gnn_toolbox/custom_modules/attacks/global_attacks/__init__.py b/gnn_toolbox/custom_components/attacks/global_attacks/__init__.py similarity index 100% rename from gnn_toolbox/custom_modules/attacks/global_attacks/__init__.py rename to gnn_toolbox/custom_components/attacks/global_attacks/__init__.py diff --git a/gnn_toolbox/custom_modules/attacks/global_attacks/dice.py b/gnn_toolbox/custom_components/attacks/global_attacks/dice.py similarity index 97% rename from gnn_toolbox/custom_modules/attacks/global_attacks/dice.py rename to gnn_toolbox/custom_components/attacks/global_attacks/dice.py index 13feb1b..d019267 100644 --- a/gnn_toolbox/custom_modules/attacks/global_attacks/dice.py +++ b/gnn_toolbox/custom_components/attacks/global_attacks/dice.py @@ -5,12 +5,12 @@ from torch_sparse import SparseTensor from tqdm import tqdm -import gnn_toolbox.custom_modules.utils as utils -from gnn_toolbox.custom_modules.attacks.base_attack import SparseAttack +import gnn_toolbox.custom_components.utils as utils +from gnn_toolbox.custom_components.attacks.base_attack import GlobalAttack from gnn_toolbox.registry import register_global_attack @register_global_attack("DICE") -class DICE(SparseAttack): +class DICE(GlobalAttack): """DICE Attack Parameters @@ -140,7 +140,7 @@ def _from_dict_to_sparse(self, adj_dict): edge_attr=edge_attr, sparse_sizes=torch.Size([self.n, self.n])) - def _attack(self, + def attack(self, n_perturbations: int, **kwargs): add_budget = int(n_perturbations * self.add_ratio) diff --git a/gnn_toolbox/custom_modules/attacks/global_attacks/greedy_rbcd.py b/gnn_toolbox/custom_components/attacks/global_attacks/greedy_rbcd.py similarity index 97% rename from gnn_toolbox/custom_modules/attacks/global_attacks/greedy_rbcd.py rename to gnn_toolbox/custom_components/attacks/global_attacks/greedy_rbcd.py index 84fcd79..bcd0551 100644 --- a/gnn_toolbox/custom_modules/attacks/global_attacks/greedy_rbcd.py +++ b/gnn_toolbox/custom_components/attacks/global_attacks/greedy_rbcd.py @@ -3,8 +3,8 @@ import torch_sparse from torch_sparse import SparseTensor -from gnn_toolbox.custom_modules import utils -from gnn_toolbox.custom_modules.attacks.global_attacks.prbcd import PRBCD +from gnn_toolbox.custom_components import utils +from gnn_toolbox.custom_components.attacks.global_attacks.prbcd import PRBCD from gnn_toolbox.registry import register_global_attack @register_global_attack("GreedyRBCD") diff --git a/gnn_toolbox/custom_components/attacks/global_attacks/new_prbcd.py b/gnn_toolbox/custom_components/attacks/global_attacks/new_prbcd.py index be9a938..001ec56 100644 --- a/gnn_toolbox/custom_components/attacks/global_attacks/new_prbcd.py +++ b/gnn_toolbox/custom_components/attacks/global_attacks/new_prbcd.py @@ -7,16 +7,18 @@ import torch.nn.functional as F from torch import Tensor from tqdm import tqdm -from torch_sparse import SparseTensor +# from torch_sparse import SparseTensor +import torch_sparse +import scipy.sparse as sp from torch_geometric.utils import coalesce, to_undirected -from gnn_toolbox.custom_components.attacks.base_attack import SparseAttack +from gnn_toolbox.custom_components.attacks.base_attack import GlobalAttack from gnn_toolbox.registry import register_global_attack # (predictions, labels, ids/mask) -> Tensor with one element LOSS_TYPE = Callable[[Tensor, Tensor, Optional[Tensor]], Tensor] @register_global_attack("PRBCD2") -class PRBCDAttack(SparseAttack): +class PRBCDAttack(GlobalAttack): r"""The Projected Randomized Block Coordinate Descent (PRBCD) adversarial attack from the `Robustness of Graph Neural Networks at Scale `_ paper. @@ -131,7 +133,7 @@ def __init__( self.coeffs.update(kwargs) - def _attack(self, n_perturbations: int, **kwargs): + def attack(self, n_perturbations: int, **kwargs): row, col, edge_attr = self.adj_adversary.coo() edge_index = torch.stack([row, col], dim=0) self.attacker(self.attr, edge_index, self.labels, n_perturbations, **kwargs) @@ -195,7 +197,11 @@ def attacker( assert flipped_edges.size(1) <= budget, ( f'# perturbed edges {flipped_edges.size(1)} ' f'exceeds budget {budget}') - self.adj_adversary = SparseTensor.from_edge_index(perturbed_edge_index, edge_weight, (self.n, self.n)) + # self.edge_index_adversary = perturbed_edge_index + edge_weight = torch.ones(perturbed_edge_index.size(1), device=self.device) + # adj = sp.csr_matrix((edge_weight.cpu(), perturbed_edge_index.cpu()), (self.num_nodes, self.num_nodes)) + # self.adj_adversary = torch_sparse.SparseTensor.from_scipy(adj).coalesce().to(self.device) + self.adj_adversary = torch_sparse.SparseTensor(row = perturbed_edge_index[0].to(self.device), col = perturbed_edge_index[1].to(self.device), value = edge_weight) # self.adj_adversary = perturbed_edge_index # return perturbed_edge_index, flipped_edges diff --git a/gnn_toolbox/custom_modules/attacks/global_attacks/prbcd.py b/gnn_toolbox/custom_components/attacks/global_attacks/prbcd.py similarity index 96% rename from gnn_toolbox/custom_modules/attacks/global_attacks/prbcd.py rename to gnn_toolbox/custom_components/attacks/global_attacks/prbcd.py index c786644..9277abe 100644 --- a/gnn_toolbox/custom_modules/attacks/global_attacks/prbcd.py +++ b/gnn_toolbox/custom_components/attacks/global_attacks/prbcd.py @@ -11,12 +11,12 @@ from torch_sparse import SparseTensor # from rgnn_at_scale.models import MODEL_TYPE -from gnn_toolbox.custom_modules import utils -from gnn_toolbox.custom_modules.attacks.base_attack import SparseAttack +from gnn_toolbox.custom_components import utils +from gnn_toolbox.custom_components.attacks.base_attack import GlobalAttack from gnn_toolbox.registry import register_global_attack @register_global_attack("PRBCD") -class PRBCD(SparseAttack): +class PRBCD(GlobalAttack): """Sampled and hence scalable PGD attack for graph data. """ @@ -93,7 +93,7 @@ def _attack(self, n_perturbations, **kwargs): # Loop over the epochs (Algorithm 1, line 5) for epoch in tqdm(range(self.epochs)): self.perturbed_edge_weight.requires_grad = True - + # self.perturbed_edge_weight.retain_grad() # Retreive sparse perturbed adjacency matrix `A \oplus p_{t-1}` (Algorithm 1, line 6) edge_index, edge_weight = self.get_modified_adj() @@ -105,9 +105,13 @@ def _attack(self, n_perturbations, **kwargs): logits = self._get_logits(self.attr, edge_index, edge_weight) # Calculate loss combining all each node (Algorithm 1, line 7) loss = self.calculate_loss(logits[self.idx_attack], self.labels[self.idx_attack]) + # self.perturbed_edge_weight.requires_grad = True + # self.perturbed_edge_weight.retain_grad() + # loss.requires_grad = True + # loss.retain_grad() # Retreive gradient towards the current block (Algorithm 1, line 7) gradient = utils.grad_with_checkpoint(loss, self.perturbed_edge_weight)[0] - + # gradient = torch.autograd.grad(loss, self.perturbed_edge_weight, allow_unused=True)[0] if torch.cuda.is_available() and self.do_synchronize: torch.cuda.empty_cache() torch.cuda.synchronize() @@ -125,7 +129,7 @@ def _attack(self, n_perturbations, **kwargs): # Calculate accuracy after the current epoch (overhead for monitoring and early stopping) edge_index, edge_weight = self.get_modified_adj() - logits = self.attacked_model(data=self.attr.to(self.device), adj=(edge_index, edge_weight)) + logits = self.attacked_model(self.attr.to(self.device), edge_index, edge_weight) accuracy = utils.accuracy(logits, self.labels, self.idx_attack) del edge_index, edge_weight, logits @@ -175,7 +179,8 @@ def _attack(self, n_perturbations, **kwargs): def _get_logits(self, attr: torch.Tensor, edge_index: torch.Tensor, edge_weight: torch.Tensor): return self.attacked_model( attr.to(self.device), - adj=(edge_index.to(self.device), edge_weight.to(self.device)) + edge_index.float().to(self.device), + edge_weight.to(self.device) ) @torch.no_grad() diff --git a/gnn_toolbox/custom_modules/attacks/local_attacks/__init__.py b/gnn_toolbox/custom_components/attacks/local_attacks/__init__.py similarity index 100% rename from gnn_toolbox/custom_modules/attacks/local_attacks/__init__.py rename to gnn_toolbox/custom_components/attacks/local_attacks/__init__.py diff --git a/gnn_toolbox/custom_modules/attacks/local_attacks/local_dice.py b/gnn_toolbox/custom_components/attacks/local_attacks/local_dice.py similarity index 95% rename from gnn_toolbox/custom_modules/attacks/local_attacks/local_dice.py rename to gnn_toolbox/custom_components/attacks/local_attacks/local_dice.py index 0723ad9..f19b865 100644 --- a/gnn_toolbox/custom_modules/attacks/local_attacks/local_dice.py +++ b/gnn_toolbox/custom_components/attacks/local_attacks/local_dice.py @@ -1,11 +1,11 @@ import torch from torch_sparse import SparseTensor -from gnn_toolbox.custom_modules.attacks.base_attack import SparseLocalAttack +from gnn_toolbox.custom_components.attacks.base_attack import LocalAttack from gnn_toolbox.registry import register_local_attack -@register_local_attack("localDICE") -class LocalDICE(SparseLocalAttack): +@register_local_attack("LocalDICE") +class LocalDICE(LocalAttack): """A Local version of the DICE Attack Parameters @@ -22,7 +22,7 @@ def __init__(self, add_ratio: float = 1.0, **kwargs): self.add_ratio = add_ratio - def _attack(self, + def attack(self, n_perturbations: int, node_idx: int, **kwargs): diff --git a/gnn_toolbox/custom_modules/attacks/local_attacks/local_prbcd.py b/gnn_toolbox/custom_components/attacks/local_attacks/local_prbcd.py similarity index 93% rename from gnn_toolbox/custom_modules/attacks/local_attacks/local_prbcd.py rename to gnn_toolbox/custom_components/attacks/local_attacks/local_prbcd.py index 946f1db..697883d 100644 --- a/gnn_toolbox/custom_modules/attacks/local_attacks/local_prbcd.py +++ b/gnn_toolbox/custom_components/attacks/local_attacks/local_prbcd.py @@ -12,14 +12,13 @@ from tqdm import tqdm -from gnn_toolbox.custom_modules.attacks.base_attack import SparseLocalAttack -from gnn_toolbox.custom_modules import utils +from gnn_toolbox.custom_components.attacks.base_attack import LocalAttack +from gnn_toolbox.custom_components import utils from gnn_toolbox.registry import register_local_attack @register_local_attack("LocalPRBCD") -class LocalPRBCD(SparseLocalAttack): - +class LocalPRBCD(LocalAttack): @typechecked def __init__(self, loss_type: str = 'Margin', @@ -55,7 +54,7 @@ def __init__(self, self.lr_factor = lr_factor self.lr_factor *= max(math.sqrt(self.n / self.block_size), 1.) - def _attack(self, n_perturbations: int, node_idx: int, **kwargs): + def attack(self, n_perturbations: int, node_idx: int, **kwargs): self.sample_search_space(node_idx, n_perturbations) best_margin = float('Inf') @@ -66,7 +65,7 @@ def _attack(self, n_perturbations: int, node_idx: int, **kwargs): logits_orig = self.get_surrogate_logits(node_idx).to(self.device) loss_orig = self.calculate_loss(logits_orig, self.labels[node_idx, None]).to(self.device) statistics_orig = LocalPRBCD.classification_statistics(logits_orig, self.labels[node_idx]) - logging.info(f'Original: Loss: {loss_orig.item()} Statstics: {statistics_orig}\n') + logging.debug(f'Original: Loss: {loss_orig.item()} Statistics: {statistics_orig}\n') del logits_orig del loss_orig @@ -84,7 +83,7 @@ def _attack(self, n_perturbations: int, node_idx: int, **kwargs): if epoch == 0: classification_statistics = LocalPRBCD.classification_statistics( logits, self.labels[node_idx].to(self.device)) - logging.info(f'Initial: Loss: {loss.item()} Statstics: {classification_statistics}\n') + logging.debug(f'Initial: Loss: {loss.item()} Statistics: {classification_statistics}\n') gradient = utils.grad_with_checkpoint(loss, self.modified_edge_weight_diff)[0] @@ -105,11 +104,11 @@ def _attack(self, n_perturbations: int, node_idx: int, **kwargs): classification_statistics = LocalPRBCD.classification_statistics( logits, self.labels[node_idx].to(self.device)) if epoch % self.display_step == 0: - logging.info(f'\nEpoch: {epoch} Loss: {loss.item()} Statstics: {classification_statistics}\n') - logging.info(f"Gradient mean {gradient.abs().mean().item()} std {gradient.abs().std().item()} " + logging.debug(f'\nEpoch: {epoch} Loss: {loss.item()} Statstics: {classification_statistics}\n') + logging.debug(f"Gradient mean {gradient.abs().mean().item()} std {gradient.abs().std().item()} " f"with base learning rate {n_perturbations * self.lr_factor}") if torch.cuda.is_available(): - logging.info(f'Cuda memory {torch.cuda.memory_allocated() / (1024 ** 3)}') + logging.debug(f'Cuda memory {torch.cuda.memory_allocated() / (1024 ** 3)}') if self.with_early_stopping and best_margin > classification_statistics['margin']: best_margin = classification_statistics['margin'] @@ -122,7 +121,7 @@ def _attack(self, n_perturbations: int, node_idx: int, **kwargs): if epoch < self.epochs_resampling - 1: self.resample_search_space(node_idx, n_perturbations, gradient) elif self.with_early_stopping and epoch == self.epochs_resampling - 1: - logging.info( + logging.debug( f'Loading search space of epoch {best_epoch} (margin={best_margin}) for fine tuning\n') self.current_search_space = best_search_space.clone().to(self.device) self.modified_edge_weight_diff = best_edge_weight_diff.clone().to(self.device) @@ -136,7 +135,7 @@ def _attack(self, n_perturbations: int, node_idx: int, **kwargs): self.perturbed_edges = torch.tensor([]) self.adj_adversary = None self.attr_adversary = self.attr - logging.info(f"Failed to attack node {node_idx} with n_perturbations={n_perturbations}") + logging.debug(f"Failed to attack node {node_idx} with n_perturbations={n_perturbations}") return None if self.with_early_stopping: @@ -301,4 +300,4 @@ def sample_final_edges(self, node_idx: int, n_perturbations: int) -> SparseTenso self.current_search_space = best_search_space.to(self.device).long() perturbed_graph = self.perturb_graph(node_idx) - return perturbed_graph + return perturbed_graph \ No newline at end of file diff --git a/gnn_toolbox/custom_modules/data_transforms/__init__.py b/gnn_toolbox/custom_components/data_transforms/__init__.py similarity index 100% rename from gnn_toolbox/custom_modules/data_transforms/__init__.py rename to gnn_toolbox/custom_components/data_transforms/__init__.py diff --git a/gnn_toolbox/custom_modules/data_transforms/transforms.py b/gnn_toolbox/custom_components/data_transforms/transforms.py similarity index 100% rename from gnn_toolbox/custom_modules/data_transforms/transforms.py rename to gnn_toolbox/custom_components/data_transforms/transforms.py diff --git a/gnn_toolbox/custom_modules/datasets/__init__.py b/gnn_toolbox/custom_components/losses/__init__.py similarity index 100% rename from gnn_toolbox/custom_modules/datasets/__init__.py rename to gnn_toolbox/custom_components/losses/__init__.py diff --git a/gnn_toolbox/custom_modules/losses/losses.py b/gnn_toolbox/custom_components/losses/losses.py similarity index 100% rename from gnn_toolbox/custom_modules/losses/losses.py rename to gnn_toolbox/custom_components/losses/losses.py diff --git a/gnn_toolbox/custom_modules/losses/__init__.py b/gnn_toolbox/custom_components/models/__init__.py similarity index 100% rename from gnn_toolbox/custom_modules/losses/__init__.py rename to gnn_toolbox/custom_components/models/__init__.py diff --git a/gnn_toolbox/custom_components/models/architecture.py b/gnn_toolbox/custom_components/models/architecture.py new file mode 100644 index 0000000..78b6942 --- /dev/null +++ b/gnn_toolbox/custom_components/models/architecture.py @@ -0,0 +1,205 @@ +import torch +import torch.nn as nn +import os +import torch.nn.functional as F +from torch_geometric.nn import GATConv, GCNConv, GCN, GAT +from torch_geometric.nn.conv.gcn_conv import gcn_norm +# from gnn_toolbox.custom_modules.models.model +# import gnn_toolbox +from gnn_toolbox.registry import register_model, registry +import os +from typing import Any, Dict, Union + + + +@register_model("GCN") +class GCNWH(nn.Module): + def __init__(self, in_channels: int, out_channels: int, hidden_channels: int, num_layers: int = 2, dropout: float = 0.5, **kwargs): + super().__init__() + self.GCNConv1 = GCNConv(in_channels=in_channels, out_channels=hidden_channels) + self.layers = nn.ModuleList([GCNConv(hidden_channels, hidden_channels) for _ in range(num_layers)]) + self.dropout = nn.Dropout(dropout) + self.GCNConv2 = GCNConv(in_channels=hidden_channels, out_channels=out_channels) + + def forward(self, x, edge_index, edge_weight, **kwargs): + x = self.GCNConv1(x, edge_index, edge_weight).relu() + x = F.dropout(x, training=self.training) + x = self.GCNConv2(x, edge_index, edge_weight) + return x + + # def compute_loss(self, pred, true): + # return F.binary_cross_entropy_with_logits(pred, true), torch.sigmoid(pred) + + + +@register_model('gcn2') +def GCN_(in_channels, out_channels, hidden_channels): + return GCN(in_channels, out_channels, hidden_channels) + +# @register_model('gat') +# def GAT + +@register_model('gcn3') +class GCN2(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2): + super().__init__() + self.norm = gcn_norm + self.conv1 = GCNConv(in_channels, hidden_channels, normalize=False) + self.conv2 = GCNConv(hidden_channels, out_channels, normalize=False) + + def reset_parameters(self): + self.conv1.reset_parameters() + self.conv2.reset_parameters() + + def forward(self, x, edge_index, edge_weight=None, **kwargs): + # Normalize edge indices only once: + if not kwargs.get('skip_norm', False): + edge_index, edge_weight = self.norm( + edge_index, + edge_weight, + num_nodes=x.size(0), + add_self_loops=True, + ) + + x = self.conv1(x, edge_index, edge_weight).relu() + x = F.dropout(x, training=self.training) + x = self.conv2(x, edge_index, edge_weight) + return x + + def _extract_model_info(self): + layers = list(self.modules()) + for idx, m in enumerate(self.modules()): + print(idx, '->', m) + +# model = GCN(16, 16, 16) +# print(model.hparams) + +# parameters = list(self.modules()) +# model._extract_model_info() + +# MODEL_TYPE = Union[GCN, GCN2] +from typing import Union, Tuple +class DenseGraphConvolution(nn.Module): + """Dense GCN convolution layer for the FGSM attack that requires a gradient towards the adjacency matrix. + """ + + def __init__(self, in_channels: int, out_channels: int): + """ + Parameters + ---------- + in_channels : int + Number of channels of the input + out_channels : int + Desired number of channels for the output (for trainable linear transform) + """ + super().__init__() + self._linear = nn.Linear(in_channels, out_channels, bias=False) + + def forward(self, arguments: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + """Prediction based on input. + + Parameters + ---------- + arguments : Tuple[torch.Tensor, torch.Tensor] + Tuple with two elements of the attributes and dense adjacency matrix + + Returns + ------- + torch.Tensor + The new embeddings + """ + x, adj_matrix = arguments + + x_trans = self._linear(x) + return adj_matrix @ x_trans + +import collections +from torch_sparse import coalesce, SparseTensor +@register_model('DenseGCN') +class DenseGCN(nn.Module): + """Dense two layer GCN for the FGSM attack that requires a gradient towards the adjacency matrix. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + hidden_channels: int = 64, + activation: nn.Module = nn.ReLU(), + dropout: float = 0.5, + ** kwargs): + """ + Parameters + ---------- + n_features : int + Number of attributes for each node + n_classes : int + Number of classes for prediction + n_filters : int, optional + number of dimensions for the hidden units, by default 80 + activation : nn.Module, optional + Arbitrary activation function for the hidden layer, by default nn.ReLU() + dropout : int, optional + Dropout rate, by default 0.5 + """ + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.activation = activation + self.dropout = dropout + self.layers = nn.ModuleList([ + nn.Sequential(collections.OrderedDict([ + ('gcn_0', DenseGraphConvolution(in_channels=in_channels, + out_channels=hidden_channels)), + ('activation_0', self.activation), + ('dropout_0', nn.Dropout(p=dropout)) + ])), + nn.Sequential(collections.OrderedDict([ + ('gcn_1', DenseGraphConvolution(in_channels=hidden_channels, + out_channels=out_channels)), + ('softmax_1', nn.LogSoftmax(dim=1)) + ])) + ]) + + @ staticmethod + def normalize_dense_adjacency_matrix(adj: torch.Tensor) -> torch.Tensor: + """Normalizes the adjacency matrix as proposed for a GCN by Kipf et al. Moreover, it only uses the upper triangular + matrix of the input to obtain the right gradient towards the undirected adjacency matrix. + + Parameters + ---------- + adj: torch.Tensor + The weighted undirected [n x n] adjacency matrix. + + Returns + ------- + torch.Tensor + Normalized [n x n] adjacency matrix. + """ + adj_norm = torch.triu(adj, diagonal=1) + torch.triu(adj, diagonal=1).T + adj_norm.data[torch.arange(adj.shape[0]), torch.arange(adj.shape[0])] = 1 + deg = torch.diag(torch.pow(adj_norm.sum(axis=1), - 1 / 2)) + adj_norm = deg @ adj_norm @ deg + return adj_norm + + def forward(self, x: torch.Tensor, adjacency_matrix: Union[torch.Tensor, SparseTensor]) -> torch.Tensor: + """Prediction based on input. + + Parameters + ---------- + x : torch.Tensor + Dense [n, d] tensor holding the attributes + adjacency_matrix : torch.Tensor + Dense [n, n] tensor for the adjacency matrix + + Returns + ------- + torch.Tensor + The predictions (after applying the softmax) + """ + if isinstance(adjacency_matrix, SparseTensor): + adjacency_matrix = adjacency_matrix.to_dense() + adjacency_matrix = DenseGCN.normalize_dense_adjacency_matrix(adjacency_matrix) + for layer in self.layers: + x = layer((x, adjacency_matrix)) + return x \ No newline at end of file diff --git a/gnn_toolbox/custom_components/models/base_model.py b/gnn_toolbox/custom_components/models/base_model.py new file mode 100644 index 0000000..e34330a --- /dev/null +++ b/gnn_toolbox/custom_components/models/base_model.py @@ -0,0 +1,52 @@ +# class BaseModel(): +# @ staticmethod +# def parse_forward_input(data: Optional[Union[Data, TensorType["n_nodes", "n_features"]]] = None, +# adj: Optional[Union[SparseTensor, +# torch.sparse.FloatTensor, +# Tuple[TensorType[2, "nnz"], TensorType["nnz"]], +# TensorType["n_nodes", "n_nodes"]]] = None, +# attr_idx: Optional[TensorType["n_nodes", "n_features"]] = None, +# edge_idx: Optional[TensorType[2, "nnz"]] = None, +# edge_weight: Optional[TensorType["nnz"]] = None, +# n: Optional[int] = None, +# d: Optional[int] = None) -> Tuple[TensorType["n_nodes", "n_features"], +# TensorType[2, "nnz"], +# TensorType["nnz"]]: +# edge_weight = None +# # PyTorch Geometric support +# if isinstance(data, Data): +# x, edge_idx = data.x, data.edge_index +# # Randomized smoothing support +# elif attr_idx is not None and edge_idx is not None and n is not None and d is not None: +# x = coalesce(attr_idx, torch.ones_like(attr_idx[0], dtype=torch.float32), m=n, n=d) +# x = torch.sparse.FloatTensor(x[0], x[1], torch.Size([n, d])).to_dense() +# edge_idx = edge_idx +# # Empirical robustness support +# elif isinstance(adj, tuple): +# # Necessary since `torch.sparse.FloatTensor` eliminates the gradient... +# x, edge_idx, edge_weight = data, adj[0], adj[1] +# elif isinstance(adj, SparseTensor): +# x = data + +# edge_idx_rows, edge_idx_cols, edge_weight = adj.coo() +# edge_idx = torch.stack([edge_idx_rows, edge_idx_cols], dim=0) +# else: +# if not adj.is_sparse: +# adj = adj.to_sparse() + +# x, edge_idx, edge_weight = data, adj._indices(), adj._values() + +# if edge_weight is None: +# edge_weight = torch.ones_like(edge_idx[0], dtype=torch.float32) + +# if edge_weight.dtype != torch.float32: +# edge_weight = edge_weight.float() + +# return x, edge_idx, edge_weight + + + + +# model(x, edge_idx, edge_weight) + +# depending on the attack, \ No newline at end of file diff --git a/gnn_toolbox/custom_modules/models/model.py b/gnn_toolbox/custom_components/models/model.py similarity index 97% rename from gnn_toolbox/custom_modules/models/model.py rename to gnn_toolbox/custom_components/models/model.py index 5d314a6..5a0c33e 100644 --- a/gnn_toolbox/custom_modules/models/model.py +++ b/gnn_toolbox/custom_components/models/model.py @@ -7,7 +7,7 @@ from pytorch_lightning import LightningModule from gnn_toolbox.old.config_def import cfg -from gnn_toolbox.custom_modules.optimizers.optimizers import register_optimizer +from gnn_toolbox.custom_components.optimizers.optimizers import register_optimizer from gnn_toolbox.registry import registry class BaseModel(LightningModule): diff --git a/gnn_toolbox/custom_modules/models/__init__.py b/gnn_toolbox/custom_components/optimizers/__init__.py similarity index 100% rename from gnn_toolbox/custom_modules/models/__init__.py rename to gnn_toolbox/custom_components/optimizers/__init__.py diff --git a/gnn_toolbox/custom_modules/optimizers/optimizers.py b/gnn_toolbox/custom_components/optimizers/optimizers.py similarity index 100% rename from gnn_toolbox/custom_modules/optimizers/optimizers.py rename to gnn_toolbox/custom_components/optimizers/optimizers.py diff --git a/gnn_toolbox/custom_modules/utils.py b/gnn_toolbox/custom_components/utils.py similarity index 76% rename from gnn_toolbox/custom_modules/utils.py rename to gnn_toolbox/custom_components/utils.py index cff072c..63868a5 100644 --- a/gnn_toolbox/custom_modules/utils.py +++ b/gnn_toolbox/custom_components/utils.py @@ -2,7 +2,7 @@ import scipy.sparse as sp from torch_sparse import SparseTensor, coalesce from typing import Tuple, Union, Sequence - +import numpy as np class DotDict(dict): """ @@ -38,12 +38,21 @@ def grad_with_checkpoint(outputs: Union[torch.Tensor, Sequence[torch.Tensor]], inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs) for input in inputs: + # input.retain_grad() if not input.is_leaf: input.retain_grad() - + torch.autograd.backward(outputs) grad_outputs = [] + # for input in inputs: + # if input.grad is not None: + # grad_outputs.append(input.grad.clone()) + # input.grad.zero_() + # else: + # # Append zeros in the same shape as the input if no gradient was computed + # grad_outputs.append(torch.ones_like(input)) + for input in inputs: grad_outputs.append(input.grad.clone()) input.grad.zero_() @@ -112,4 +121,23 @@ def sparse_tensor(spmat: sp.spmatrix, grad: bool = False): dtype = torch.uint8 else: dtype = torch.float32 - return SparseTensor.from_scipy(spmat).to(dtype).coalesce() \ No newline at end of file + return SparseTensor.from_scipy(spmat).to(dtype).coalesce() + +def accuracy(logits: torch.Tensor, labels: torch.Tensor, split_idx: np.ndarray) -> float: + """Returns the accuracy for a tensor of logits, a list of lables and and a split indices. + + Parameters + ---------- + prediction : torch.Tensor + [n x c] tensor of logits (`.argmax(1)` should return most probable class). + labels : torch.Tensor + [n x 1] target label. + split_idx : np.ndarray + [?] array with indices for current split. + + Returns + ------- + float + the Accuracy + """ + return (logits.argmax(1)[split_idx] == labels[split_idx]).float().mean().item() diff --git a/gnn_toolbox/custom_modules/datasets/datasets.py b/gnn_toolbox/custom_modules/datasets/datasets.py deleted file mode 100644 index 51c3367..0000000 --- a/gnn_toolbox/custom_modules/datasets/datasets.py +++ /dev/null @@ -1,27 +0,0 @@ -from torch_geometric.datasets import ( - Planetoid, - TUDataset, - CoraFull, - Amazon, - Coauthor, - PPI, - Reddit, - GNNBenchmarkDataset -) -from ogb.nodeproppred import PygNodePropPredDataset -from gnn_toolbox.registry import register_dataset - - -register_dataset('Cora', Planetoid) -register_dataset('Citeseer', Planetoid) -register_dataset('Pubmed', Planetoid) -register_dataset('CoraFull', CoraFull) -register_dataset('PPI', PPI) -register_dataset('Amazon', Amazon) -register_dataset('Coauthor', Coauthor) -register_dataset('Reddit', Reddit) -register_dataset('GNNBenchmarkDataset', GNNBenchmarkDataset) -register_dataset('ogb-arxiv', PygNodePropPredDataset) - - - diff --git a/gnn_toolbox/custom_modules/models/architecture.py b/gnn_toolbox/custom_modules/models/architecture.py deleted file mode 100644 index 58fa29d..0000000 --- a/gnn_toolbox/custom_modules/models/architecture.py +++ /dev/null @@ -1,72 +0,0 @@ -import torch -import torch.nn as nn -import os -import torch.nn.functional as F -from torch_geometric.nn import GATConv, GCNConv -from torch_geometric.nn.conv.gcn_conv import gcn_norm -# from gnn_toolbox.custom_modules.models.model -# import gnn_toolbox -from gnn_toolbox.registry import register_model, registry -import os -from typing import Any, Dict, Union - - -@register_model("GCN") -class GCN(nn.Module): - def __init__(self, in_channels: int, out_channels: int, hidden_channels: int, num_layers: int = 2, dropout: float = 0.5, **kwargs): - super().__init__() - self.GCNConv = GCNConv(in_channels=in_channels, out_channels=hidden_channels) - self.layers = nn.ModuleList([GCNConv(hidden_channels, hidden_channels) for _ in range(num_layers)]) - self.dropout = nn.Dropout(dropout) - self.linear = nn.Linear(hidden_channels, out_channels) - - def forward(self, x, edge_index, **kwargs): - x = self.GCNConv(x, edge_index) - for layer in self.layers: - x = F.relu(layer(x, edge_index)) - x = self.dropout(x) - return self.linear(x) - - # def compute_loss(self, pred, true): - # return F.binary_cross_entropy_with_logits(pred, true), torch.sigmoid(pred) - - -@register_model('gcn2') -class GCN2(nn.Module): - def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2): - super().__init__() - self.norm = gcn_norm - self.conv1 = GCNConv(in_channels, hidden_channels, normalize=False) - self.conv2 = GCNConv(hidden_channels, out_channels, normalize=False) - - def reset_parameters(self): - self.conv1.reset_parameters() - self.conv2.reset_parameters() - - def forward(self, x, edge_index, edge_weight=None, **kwargs): - # Normalize edge indices only once: - if not kwargs.get('skip_norm', False): - edge_index, edge_weight = self.norm( - edge_index, - edge_weight, - num_nodes=x.size(0), - add_self_loops=True, - ) - - x = self.conv1(x, edge_index, edge_weight).relu() - x = F.dropout(x, training=self.training) - x = self.conv2(x, edge_index, edge_weight) - return x - - def _extract_model_info(self): - layers = list(self.modules()) - for idx, m in enumerate(self.modules()): - print(idx, '->', m) - -# model = GCN(16, 16, 16) -# print(model.hparams) - -# parameters = list(self.modules()) -# model._extract_model_info() - -# MODEL_TYPE = Union[GCN, GCN2] diff --git a/gnn_toolbox/custom_modules/optimizers/__init__.py b/gnn_toolbox/custom_modules/optimizers/__init__.py deleted file mode 100644 index 4263881..0000000 --- a/gnn_toolbox/custom_modules/optimizers/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from os.path import dirname, basename, isfile, join -import glob - -modules = glob.glob(join(dirname(__file__), "*.py")) -__all__ = [ - basename(f)[:-3] for f in modules - if isfile(f) and not f.endswith('__init__.py') -] \ No newline at end of file diff --git a/gnn_toolbox/experiment_handler/config_validator.py b/gnn_toolbox/experiment_handler/config_validator.py index cdc776e..07ff305 100644 --- a/gnn_toolbox/experiment_handler/config_validator.py +++ b/gnn_toolbox/experiment_handler/config_validator.py @@ -91,8 +91,8 @@ class Attack(BaseModel): epsilon: Union[List[Union[PositiveInt, PositiveFloat]], PositiveInt, PositiveFloat] nodes: Optional[List[NonNegativeInt]] = None min_node_degree: Optional[PositiveInt] = None - topk: Optional[PositiveInt] = None - + nodes_topk: Optional[PositiveInt] = None + params: Optional[Dict[str, PARAMS_TYPE]] = {} @model_validator(mode='after') def validate_scope(self): if self.scope == 'local': diff --git a/gnn_toolbox/experiment_handler/exp_runner.py b/gnn_toolbox/experiment_handler/exp_runner.py index 1afbed5..41be0eb 100644 --- a/gnn_toolbox/experiment_handler/exp_runner.py +++ b/gnn_toolbox/experiment_handler/exp_runner.py @@ -1,8 +1,7 @@ -from gnn_toolbox.custom_modules import * +from gnn_toolbox.custom_components import * from gnn_toolbox.common import (prepare_dataset, evaluate_model, train, - classification_statistics, gen_local_attack_nodes) from gnn_toolbox.experiment_handler.create_modules import create_model, create_global_attack, create_local_attack, create_dataset, create_optimizer, create_loss @@ -33,7 +32,7 @@ def run_experiment(experiment, experiment_dir, artifact_manager): attr, adj, labels, split, num_edges = prepare_dataset(dataset, experiment, graph_index, make_undirected) model = create_model(experiment['model']) - + untrained_model_state_dict = model.state_dict() # _, accuracy = evaluate_global(model=model, attr=attr, adj=adj, labels=labels, idx_test=split['test'], device=device) # print('untrained model accuracy', accuracy) @@ -43,7 +42,7 @@ def run_experiment(experiment, experiment_dir, artifact_manager): perturbed_result = None if(experiment['attack']['scope'] == 'global'): if experiment['attack']['type'] == 'poison': - pert_adj, pert_attr = global_attack(experiment['attack'], attr, adj, labels, split['train'], model, device, num_edges, make_undirected) + adversarial_attack, n_perturbations = instantiate_global_attack(experiment['attack'], attr, adj, labels, split['train'], model, device, num_edges, make_undirected) model_path, perturbed_result = artifact_manager.model_exists(experiment, is_unattacked_model=False) @@ -52,9 +51,24 @@ def run_experiment(experiment, experiment_dir, artifact_manager): # model.to(device) # return # else: - perturbed_result = train_and_evaluate(model, pert_attr, pert_adj, attr, adj, labels, split, device, experiment, artifact_manager, is_unattacked_model=False) + try: + adversarial_attack.attack(n_perturbations) + pert_adj, pert_attr = adversarial_attack.get_perturbations() + perturbed_result = train_and_evaluate(model, pert_attr, pert_adj, attr, adj, labels, split, device, experiment, artifact_manager, is_unattacked_model=False) + except Exception as e: + logging.exception(e) + logging.error(f"Global poisoning adversarial attack {experiment['attack']['name']} failed to attack the model {experiment['model']['name']}") + return elif experiment['attack']['type'] == 'evasion': - pert_adj, pert_attr = global_attack(experiment['attack'], attr, adj, labels, split['test'], model, device, num_edges, make_undirected) + adversarial_attack, n_perturbations = instantiate_global_attack(experiment['attack'], attr, adj, labels, split['test'], model, device, num_edges, make_undirected) + + try: + adversarial_attack.attack(n_perturbations) + pert_adj, pert_attr = adversarial_attack.get_perturbations() + except Exception as e: + logging.exception(e) + logging.error(f"Global evasion adversarial attack {experiment['attack']['name']} failed to attack the model {experiment['model']['name']}") + return logits, accuracy = evaluate_model(model=model, attr=pert_attr, adj=pert_adj, labels=labels, idx_test=split['test'], device=device) @@ -64,11 +78,11 @@ def run_experiment(experiment, experiment_dir, artifact_manager): } elif(experiment['attack']['scope'] == 'local'): - perturbed_result = local_attack(experiment, attr, adj, labels, split, model, device, make_undirected) + perturbed_result = execute_local_attack(experiment, attr, adj, labels, split, model, device, make_undirected) all_result = { 'clean_result': clean_result, - 'perturbed_result': perturbed_result, + 'perturbed_result': perturbed_result if perturbed_result is not None else None, } # log to tensorboard return all_result, experiment @@ -86,7 +100,7 @@ def clean_train(current_config, artifact_manager, model, attr, adj, labels, spli return result -def train_and_evaluate(model, train_attr, train_adj, test_attr, test_adj, labels, split, device, current_config, artifact_manager, is_unattacked_model): +def train_and_evaluate(model, train_attr, train_adj, test_attr, test_adj, labels, split, device, current_config, artifact_manager, is_unattacked_model, untrained_model_state_dict=None): model = model.to(device) train_attr = train_attr.to(device) train_adj = train_adj.to(device) @@ -94,7 +108,8 @@ def train_and_evaluate(model, train_attr, train_adj, test_attr, test_adj, labels test_adj = test_adj.to(device) labels = labels.to(device) - + if untrained_model_state_dict is not None: + model.load(untrained_model_state_dict) for module in model.modules(): if hasattr(module, 'reset_parameters'): module.reset_parameters() @@ -121,27 +136,32 @@ def train_and_evaluate(model, train_attr, train_adj, test_attr, test_adj, labels return result -def global_attack(attack_info, attr, adj, labels, idx_attack, model, device, num_edges, make_undirected): +def instantiate_global_attack(attack_info, attr, adj, labels, idx_attack, model, device, num_edges, make_undirected): attack_params = getattr(attack_info, 'params', {}) - attack_model = create_global_attack(attack_info['name'])( - attr=attr, adj=adj, labels=labels, idx_attack=idx_attack, - model=model, device=device, make_undirected=make_undirected, **attack_params) - - n_perturbations = int(round(attack_info['epsilon'] * num_edges)) - attack_model.attack(n_perturbations) - return attack_model.get_perturbations() - - -def local_attack(experiment, attr, adj, labels, split, model, device, make_undirected): + try: + attack_model = create_global_attack(attack_info['name'])( + attr=attr, adj=adj, labels=labels, idx_attack=idx_attack, + model=model, device=device, make_undirected=make_undirected, **attack_params) + n_perturbations = int(round(attack_info['epsilon'] * num_edges)) + # attack_model.attack(n_perturbations) + return attack_model, n_perturbations + except Exception as e: + logging.exception(e) + logging.error(f"Failed to create the global adversarial attack '{attack_info['name']}'.") + +def execute_local_attack(experiment, attr, adj, labels, split, model, device, make_undirected): attack_params = getattr(experiment['attack'], 'params', {}) - attack_model = create_local_attack(experiment['attack']['name'])( - attr=attr, adj=adj, labels=labels, - idx_attack=split['test'], - model=model, device=device, make_undirected=make_undirected, **attack_params) - + try: + attack_model = create_local_attack(experiment['attack']['name'])( + attr=attr, adj=adj, labels=labels, + idx_attack=split['test'], + model=model, device=device, make_undirected=make_undirected, **attack_params) + except Exception as e: + logging.exception(e) + logging.error(f"Failed to create local adversarial attack '{experiment['attack']['name']}'.") + results = [] eps = experiment['attack']['epsilon'] - # nodes = experiment.attack.nodes if 'nodes' not in experiment['attack']: epsilon_inverse = int(1 / eps) @@ -150,7 +170,7 @@ def local_attack(experiment, attr, adj, labels, split, model, device, make_undir # topk = int(experiment.attack.topk) if experiment.attack.topk is not None else 10 - topk = int(experiment['attack'].get('topk', 10)) + topk = int(experiment['attack'].get('nodes_topk', 10)) nodes = gen_local_attack_nodes(attr, adj, labels, model, split['train'], device, topk=topk, min_node_degree=min_node_degree) else: @@ -161,18 +181,18 @@ def local_attack(experiment, attr, adj, labels, split, model, device, make_undir n_perturbations = int((eps * degree).round().item()) if n_perturbations == 0: logging.error( - f"Skipping attack for model '{model}' using {experiment['attack']['name']} with eps {eps} at node {node}.") + f"Number of perturbations is 0 for model {experiment['model']['name']} using {experiment['attack']['name']} with eps {eps} at node {node}. Skipping the attack to node {node}") continue try: attack_model.attack(n_perturbations, node_idx=node) except Exception as e: logging.exception(e) logging.error( - f"Failed to attack model '{model}' using {experiment['attack']['name']} with eps {eps} at node {node}.") + f"Adversarial attack {experiment['attack']['name']} failed to attack the model {experiment['model']['name']} using with eps {eps} at node {node}.") continue - logits, initial_logits = attack_model.evaluate_local(node) - + logits, initial_logits = attack_model.evaluate_node(node) + results.append({ 'node index': node, 'node degree': int(degree.item()), @@ -181,11 +201,11 @@ def local_attack(experiment, attr, adj, labels, split, model, device, make_undir 'perturbed_edges': attack_model.get_perturbed_edges().cpu().numpy().tolist(), 'results before attacking (unperturbed data)': { 'logits': initial_logits.cpu().numpy().tolist(), - **classification_statistics(initial_logits.cpu(), labels[node].long().cpu()) + **attack_model.classification_statistics(initial_logits.cpu(), labels[node].long().cpu()) }, 'results after attacking (perturbed data)': { 'logits': logits.cpu().numpy().tolist(), - **classification_statistics(logits.cpu(), labels[node].long().cpu()) + **attack_model.classification_statistics(logits.cpu(), labels[node].long().cpu()) } }) @@ -208,16 +228,18 @@ def local_attack(experiment, attr, adj, labels, split, model, device, make_undir _ = train(model=victim, attr=perturbed_attr.to(device), adj=perturbed_adj.to(device), labels=labels.to(device), idx_train=split['train'], idx_val=split['valid'], idx_test=split['test'], optimizer=optimizer, loss=loss, **experiment['training']) attack_model.set_eval_model(victim) - logits_poisoning, _ = attack_model.evaluate_local(node) + logits_poisoning, _ = attack_model.evaluate_node(node) attack_model.set_eval_model(model) results[-1].update({ 'results after attacking (perturbed data)': { 'logits': logits_poisoning.cpu().numpy().tolist(), - **classification_statistics(logits_poisoning.cpu(), labels[node].long().cpu()) + **attack_model.classification_statistics(logits_poisoning.cpu(), labels[node].long().cpu()) }, - 'pyg_margin': attack_model._probability_margin_loss(victim(attr.to(device), adj.to(device)),labels, [node]).item() + # 'pyg_margin': attack_model._probability_margin_loss(victim(attr.to(device), adj.to(device)),labels, [node]).item() }) + logging.info(f'Node {node} with perturbed edges evaluated on model {experiment["model"]["name"]} using adversarial attack {experiment["attack"]["name"]} with epsilon {eps}') + logging.debug(results[-1]) assert len(results) > 0, "No attack could be made." return results diff --git a/gnn_toolbox/old/attacks/robustness_of_gnns_at_scale/dice.py b/gnn_toolbox/old/attacks/robustness_of_gnns_at_scale/dice.py index b010ae5..4ce6054 100644 --- a/gnn_toolbox/old/attacks/robustness_of_gnns_at_scale/dice.py +++ b/gnn_toolbox/old/attacks/robustness_of_gnns_at_scale/dice.py @@ -6,7 +6,7 @@ from tqdm import tqdm from .base_attack import SparseAttack -import gnn_toolbox.custom_modules.utils as utils +import gnn_toolbox.custom_components.utils as utils from gnn_toolbox.registry import register_attack @register_attack("DICE") diff --git a/gnn_toolbox/old/attacks/robustness_of_gnns_at_scale/greedy_rbcd.py b/gnn_toolbox/old/attacks/robustness_of_gnns_at_scale/greedy_rbcd.py index 73a2cfe..c33f35a 100644 --- a/gnn_toolbox/old/attacks/robustness_of_gnns_at_scale/greedy_rbcd.py +++ b/gnn_toolbox/old/attacks/robustness_of_gnns_at_scale/greedy_rbcd.py @@ -3,7 +3,7 @@ import torch_sparse from torch_sparse import SparseTensor -from gnn_toolbox.custom_modules import utils +from gnn_toolbox.custom_components import utils from gnn_toolbox.attacks.robustness_of_gnns_at_scale.prbcd import PRBCD diff --git a/gnn_toolbox/old/attacks/robustness_of_gnns_at_scale/local_prbcd.py b/gnn_toolbox/old/attacks/robustness_of_gnns_at_scale/local_prbcd.py index a2ef6a4..fce3512 100644 --- a/gnn_toolbox/old/attacks/robustness_of_gnns_at_scale/local_prbcd.py +++ b/gnn_toolbox/old/attacks/robustness_of_gnns_at_scale/local_prbcd.py @@ -14,7 +14,7 @@ from rgnn_at_scale.helper.utils import grad_with_checkpoint, to_symmetric from rgnn_at_scale.attacks.base_attack import Attack, SparseLocalAttack -from gnn_toolbox.custom_modules import utils +from gnn_toolbox.custom_components import utils class LocalPRBCD(SparseLocalAttack): diff --git a/gnn_toolbox/old/attacks/robustness_of_gnns_at_scale/prbcd.py b/gnn_toolbox/old/attacks/robustness_of_gnns_at_scale/prbcd.py index 4e4e796..45b8d64 100644 --- a/gnn_toolbox/old/attacks/robustness_of_gnns_at_scale/prbcd.py +++ b/gnn_toolbox/old/attacks/robustness_of_gnns_at_scale/prbcd.py @@ -11,7 +11,7 @@ from torch_sparse import SparseTensor # from rgnn_at_scale.models import MODEL_TYPE -from gnn_toolbox.custom_modules import utils +from gnn_toolbox.custom_components import utils from gnn_toolbox.attacks.base_attack import Attack, SparseAttack diff --git a/gnn_toolbox/old/base.py b/gnn_toolbox/old/base.py index 429c3af..05dd7d9 100644 --- a/gnn_toolbox/old/base.py +++ b/gnn_toolbox/old/base.py @@ -3,10 +3,10 @@ from config_def import cfg, set_run_dir import yaml import logging -from gnn_toolbox.custom_modules.lightning_train import train +from gnn_toolbox.custom_components.lightning_train import train from torch_geometric import seed_everything from gnn_toolbox.old.utils3 import create_model, auto_select_device -from gnn_toolbox.custom_modules.lightning_train import BaseDataModule +from gnn_toolbox.custom_components.lightning_train import BaseDataModule from sacred import Experiment ex = Experiment("Run_experiments") diff --git a/gnn_toolbox/custom_modules/attacks/global_attacks/fgsm.py b/gnn_toolbox/old/fgsm.py similarity index 79% rename from gnn_toolbox/custom_modules/attacks/global_attacks/fgsm.py rename to gnn_toolbox/old/fgsm.py index a3e7284..da71398 100644 --- a/gnn_toolbox/custom_modules/attacks/global_attacks/fgsm.py +++ b/gnn_toolbox/old/fgsm.py @@ -5,7 +5,7 @@ import torch from torch_sparse import SparseTensor -from gnn_toolbox.custom_modules.attacks.base_attack import DenseAttack +from gnn_toolbox.custom_components.attacks.base_attack import DenseAttack from gnn_toolbox.registry import register_global_attack @@ -31,11 +31,18 @@ def __init__(self, **kwargs): super().__init__(**kwargs) assert self.make_undirected, 'Attack only implemented for undirected graphs' - + # self.adj_perturbed = self.adj.clone() + # self.adj_perturbed = self.adj.T.clone().requires_grad_(True).to(self.device) self.adj_perturbed = self.adj.clone().requires_grad_(True).to(self.device) self.n_perturbations = 0 - + from torch_geometric.datasets import Planetoid + from torch_geometric.transforms import ToUndirected + cora = Planetoid(root='datasets', name='Cora', transform=ToUndirected()) + data = cora[0] + # self.adj_perturbed = data.edge_index.clone().requires_grad_(True).to(self.device) + self.adj = self.adj.to(self.device) + # self.attr = data.x.to(self.device) self.attr = self.attr.to(self.device) self.attacked_model = self.attacked_model.to(self.device) @@ -55,7 +62,9 @@ def _attack(self, n_perturbations: int): self.n_perturbations += n_perturbations for _ in tqdm(range(n_perturbations)): - logits = self.attacked_model(self.attr, self.adj_perturbed) + print('shape self.adj_perturbed', self.adj_perturbed.shape) + print('shape self.attr', self.attr.shape) + logits = self.attacked_model(self.attr, self.adj_perturbed.data) loss = self.calculate_loss(logits[self.idx_attack], self.labels[self.idx_attack]) diff --git a/gnn_toolbox/old/lightning_train.py b/gnn_toolbox/old/lightning_train.py index c5aaf55..c627a89 100644 --- a/gnn_toolbox/old/lightning_train.py +++ b/gnn_toolbox/old/lightning_train.py @@ -2,7 +2,7 @@ from torch_geometric.data.lightning.datamodule import LightningDataModule from customize.util import get_ckpt_dir from gnn_toolbox.old.config_def import cfg -from gnn_toolbox.custom_modules.models.model import BaseModel +from gnn_toolbox.custom_components.models.model import BaseModel from custom_modules.loader import create_loader from logger import LoggerCallback diff --git a/gnn_toolbox/custom_modules/attacks/local_attacks/nettack.py b/gnn_toolbox/old/nettack.py similarity index 98% rename from gnn_toolbox/custom_modules/attacks/local_attacks/nettack.py rename to gnn_toolbox/old/nettack.py index d9d131d..3483478 100644 --- a/gnn_toolbox/custom_modules/attacks/local_attacks/nettack.py +++ b/gnn_toolbox/old/nettack.py @@ -9,8 +9,8 @@ from torch.nn import Identity from tqdm import tqdm -from gnn_toolbox.custom_modules.utils import sparse_tensor -from gnn_toolbox.custom_modules.attacks.base_attack import SparseLocalAttack +from gnn_toolbox.custom_components.utils import sparse_tensor +from gnn_toolbox.custom_components.attacks.base_attack import LocalAttack from gnn_toolbox.registry import register_local_attack """ @@ -25,7 +25,7 @@ """ @register_local_attack("Nettack") -class Nettack(SparseLocalAttack): +class Nettack(LocalAttack): """Wrapper around the implementation of the method proposed in the paper: 'Adversarial Attacks on Neural Networks for Graph Data' by Daniel Zügner, Amir Akbarnejad and Stephan Günnemann, @@ -46,13 +46,13 @@ class Nettack(SparseLocalAttack): """ def __init__(self, **kwargs): - SparseLocalAttack.__init__(self, **kwargs) + LocalAttack.__init__(self, **kwargs) assert self.make_undirected, 'Attack only implemented for undirected graphs' - assert len(self.attacked_model.layers) == 2, "Nettack supports only 2 Layer Linear GCN as surrogate model" - assert isinstance(self.attacked_model._modules['activation'], Identity), \ - "Nettack only supports Linear GCN as surrogate model" + # assert len(self.attacked_model.layers) == 2, "Nettack supports only 2 Layer Linear GCN as surrogate model" + # assert isinstance(self.attacked_model._modules['act'], Identity), \ + # "Nettack only supports Linear GCN as surrogate model" self.sp_adj = self.adj.to_scipy(layout="csr") self.sp_attr = SparseTensor.from_dense(self.attr).to_scipy(layout="csr") diff --git a/gnn_toolbox/custom_modules/attacks/global_attacks/pgd.py b/gnn_toolbox/old/pgd.py similarity index 97% rename from gnn_toolbox/custom_modules/attacks/global_attacks/pgd.py rename to gnn_toolbox/old/pgd.py index ae9687d..56ec6cb 100644 --- a/gnn_toolbox/custom_modules/attacks/global_attacks/pgd.py +++ b/gnn_toolbox/old/pgd.py @@ -12,7 +12,7 @@ from torch_sparse import SparseTensor from tqdm import tqdm -from gnn_toolbox.custom_modules.attacks.base_attack import DenseAttack +from gnn_toolbox.custom_components.attacks.base_attack import DenseAttack from gnn_toolbox.registry import register_global_attack @register_global_attack("PGD") @@ -72,6 +72,8 @@ def _attack(self, n_perturbations: int, **kwargs): self.attacked_model.eval() for t in tqdm(range(self.epochs)): modified_adj = self.get_modified_adj() + print('shape1', modified_adj.shape) + print('shape2', self.attr.shape) logits = self.attacked_model(self.attr, modified_adj) loss = self.calculate_loss(logits[self.idx_attack], self.labels[self.idx_attack]) adj_grad = torch.autograd.grad(loss, self.adj_changes)[0] diff --git a/gnn_toolbox/registry.py b/gnn_toolbox/registry.py index b7de4e0..8ca3bf3 100644 --- a/gnn_toolbox/registry.py +++ b/gnn_toolbox/registry.py @@ -1,46 +1,68 @@ from typing import Any, Callable, Dict, Union from functools import partial -from gnn_toolbox.custom_modules import * -ModuleType = Any +from gnn_toolbox.custom_components import * # noqa +import inspect -registry: Dict[str, Dict[str, ModuleType]] = { +MODULE_TYPE = Any + +registry: Dict[str, Dict[str, MODULE_TYPE]] = { "model": {}, - "global_attack":{}, - "local_attack":{}, + "global_attack": {}, + "local_attack": {}, "dataset": {}, "transform": {}, "optimizer": {}, "loss": {}, } -def register_module(category: str, key: str, module: ModuleType = None) -> Union[Callable, None]: + +def register_component( + category: str, key: str, component: MODULE_TYPE = None +) -> Union[Callable, None]: """ - Registers a module. + Registers a component. Args: - category (str): The category of the module (e.g., "act", "node_encoder"). - key (str): The name of the module. - module (any, optional): The module. If set to None, will return a decorator. + category (str): The category of the component (e.g., "act", "node_encoder"). + key (str): The name of the component. + component (any, optional): The component. If set to None, will return a decorator. """ if category not in registry: - raise ValueError(f"Category '{category}' is not valid. Please choose from {list(registry.keys())}.") - - if module is not None: + raise ValueError( + f"Category '{category}' is not valid. Please choose from {list(registry.keys())}." + ) + + if component is not None: if key in registry[category]: - raise KeyError(f"Module with '{key}' already defined in category '{category}'") - registry[category][key] = module + raise KeyError( + f"Component with '{key}' already defined in category '{category}'" + ) + if key == 'model': + try: + check_model_signature(component) + except Exception as e: + raise ValueError(f"Failed to validate model signature: {e}") + registry[category][key] = component return - def register_by_decorator(module): - register_module(category, key, module) - return module + def register_by_decorator(component): + register_component(category, key, component) + return component return register_by_decorator -def get_from_registry(category: str, key: str, registry: Dict[str, Dict[str, ModuleType]], default: Any = None) -> Any: - """Retrieve a module from the registry safely with a fallback.""" + +def get_from_registry( + category: str, + key: str, + registry: Dict[str, Dict[str, MODULE_TYPE]], + default: Any = None, +) -> Any: + """Retrieve a component from the registry safely with a fallback.""" if category not in registry: - raise ValueError(f"Category '{category}' is not recognized. Available categories: {list(registry.keys())}") + raise ValueError( + f"Category '{category}' is not recognized. Available categories: {list(registry.keys())}" + ) category_registry = registry[category] if key in category_registry: @@ -49,12 +71,60 @@ def get_from_registry(category: str, key: str, registry: Dict[str, Dict[str, Mod if default is not None: return default else: - raise KeyError(f"Module '{key}' not found in category '{category}'. Available options: {list(category_registry.keys())}") - -register_model = partial(register_module, "model") -register_global_attack = partial(register_module, "global_attack") -register_local_attack = partial(register_module, "local_attack") -register_dataset = partial(register_module, "dataset") -register_transform = partial(register_module, "transform") -register_optimizer = partial(register_module, "optimizer") -register_loss = partial(register_module, "loss") \ No newline at end of file + raise KeyError( + f"Component '{key}' not found in category '{category}'. Available options: {list(category_registry.keys())}" + ) + + +def check_model_signature(model): + sig = inspect.signature(model.forward) + parameters = [ + param.name for param in sig.parameters.values() if param.name != "self" + ] + + allowed_signatures = [ + ["x", "edge_index"], + ["x", "edge_index", "edge_weight"] + ] + + if parameters not in allowed_signatures: + raise TypeError( + f"Invalid forward parameters. Allowed parameters are {allowed_signatures}." + ) + + +# def register_model(func): +# @wraps(func) +# def wrapper(*args, **kwargs): +# # Check the forward method's signature +# sig = inspect.signature(args[0].forward) +# parameters = [ +# param.name for param in sig.parameters.values() if param.name != "self" +# ] + +# # Define allowed signatures +# allowed_signatures = [ +# ["x", "edge_index"], +# ["x", "edge_index", "edge_weight"], +# ["x", "edge_index", "edge_attr"], +# ] + +# # Check if the model's parameters match any of the allowed signatures +# if parameters not in allowed_signatures: +# raise TypeError( +# f"Invalid forward parameters. Allowed parameters are {allowed_signatures}." +# ) + +# # If valid, call the original function (usually the model's initializer) +# # return func(*args, **kwargs) + +# return wrapper +# return + +register_model = partial(register_component, "model") +register_global_attack = partial(register_component, "global_attack") +register_local_attack = partial(register_component, "local_attack") +register_dataset = partial(register_component, "dataset") +register_transform = partial(register_component, "transform") +register_optimizer = partial(register_component, "optimizer") +register_loss = partial(register_component, "loss") diff --git a/main.py b/main.py index 9226e58..3932b0f 100644 --- a/main.py +++ b/main.py @@ -11,7 +11,7 @@ def main(file): try: experiments_config = load_and_validate_yaml(file) experiments = generate_experiments_from_yaml(experiments_config) - artifact_manager = ArtifactManager('cache') + artifact_manager = ArtifactManager('cache2') logging.info(f'Running {len(experiments)} experiments') for curr_dir, experiment in experiments.items(): result, experiment_cfg = run_experiment(experiment, curr_dir, artifact_manager)