From 8eea836e70de3f70bbc39f8fddf940187df8e93c Mon Sep 17 00:00:00 2001
From: Munkhtenger Munkh-Aldar
Date: Mon, 27 May 2024 18:33:27 +0200
Subject: [PATCH] add more GNN models and fix

---
 configs/default_experiment.yaml               |  42 +--
 custom_components/attacks/base_attack.py      |  76 ++---
 .../attacks/global_attacks/greedy_rbcd.py     | 303 ++++++++++++------
 .../attacks/global_attacks/new_prbcd.py       |   1 -
 custom_components/datasets/datasets.py        |  44 ++-
 custom_components/models/AirGNN.py            | 187 +++++++++++
 custom_components/models/GAT.py               |  88 +++++
 custom_components/models/GCN.py               |  88 +++++
 custom_components/models/GPR.py               | 132 ++++++++
 custom_components/models/SAGE.py              | 157 +++++++++
 custom_components/models/architecture.py      |  30 +-
 custom_components/models/base_model.py        |  52 ---
 custom_components/models/model.py             |  90 ------
 custom_components/utils.py                    |   9 +
 gnn_toolbox/common.py                         |  22 +-
 .../experiment_handler/config_validator.py    |   7 +-
 gnn_toolbox/experiment_handler/exp_runner.py  |   2 +-
 main.py                                       |   6 +-
 tests/conftest.py                             |   2 +
 tests/test_DICE_result_with_DPR.py            |  68 ++++
 tests/test_artifact_saving_retrieving.py      |  65 +---
 tests/test_bulk_data.py                       |   0
 tests/test_configs/DICE_evasion.yaml          |  29 ++
 tests/test_configs/invalid_config.yaml        |  18 +-
 ...r.py => test_experiment_yaml_validator.py} |  60 +---
 ...tion.py => test_registration_component.py} |   0
 26 files changed, 1127 insertions(+), 451 deletions(-)
 create mode 100644 custom_components/models/AirGNN.py
 create mode 100644 custom_components/models/GAT.py
 create mode 100644 custom_components/models/GCN.py
 create mode 100644 custom_components/models/GPR.py
 create mode 100644 custom_components/models/SAGE.py
 delete mode 100644 custom_components/models/base_model.py
 delete mode 100644 custom_components/models/model.py
 create mode 100644 tests/conftest.py
 create mode 100644 tests/test_DICE_result_with_DPR.py
 delete mode 100644 tests/test_bulk_data.py
 create mode 100644 tests/test_configs/DICE_evasion.yaml
 rename tests/{test_yaml_validator.py => test_experiment_yaml_validator.py} (84%)
 rename tests/{test_component_registration.py => test_registration_component.py} (100%)

diff --git a/configs/default_experiment.yaml b/configs/default_experiment.yaml
index a36ecdf..6f4eaa0 100644
--- a/configs/default_experiment.yaml
+++ b/configs/default_experiment.yaml
@@ -1,13 +1,13 @@
 output_dir: ./output10
-# resume_output: False
+resume_output: True
 # csv_save: True
 # esd: True
 experiment_templates:
   - name: Local attack_3
-    seed: [0]
+    seed: [1]
     device: cuda
     model:
-      - name: GCN2
+      - name: [GCN_DPR]
         params:
           hidden_channels: 64
     dataset:
       - name: Cora
       # name: [CoraFull, CS, Physics, Computers, Photo, CoraCitationFull, DBLP, PubMedCitationFull, Reddit] # ogbn-arxiv
         root: ./datasets
         make_undirected: true
-        params:
-          split: full
+        # params:
+        #   split: full
       # train_ratio: 0.6
       # test_ratio: 0.2
       # val_ratio: 0.2
     # aggr:
     #   name: noise_aggregation
     #   params:
     #     value: 0.8
-    attack:
-      # - scope: local
-      #   name: [LocalDICE]
-      #   type: poison
-      #   epsilon: [0.5]
-      #   nodes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
-      # - scope: global
-      #   name: FGSM
-      #   type: poison
-      #   epsilon: [0.5]
-      - scope: global
-        name: DICE
-        type: poison
-        epsilon: [0.5]
+    # attack:
+    #   # - scope: local
+    #   #   name: [LocalDICE]
+    #   #   type: poison
+    #   #   epsilon: [0.5]
+    #   #   nodes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+    #   # - scope: global
+    #   #   name: FGSM
+    #   #   type: poison
+    #   #   epsilon: [0.5]
+    #   - scope: global
+    #     name: DICE
+    #     type: poison
+    #     epsilon: [0.25]

     # - scope: local
     #   name: [LocalDICE]
@@ -64,8 +64,8 @@ experiment_templates:
     training:
-      max_epochs: 100
-      patience: 70
+      max_epochs: 300
+      patience: 290
     optimizer:
       name: adam
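A note on the template above: list-valued fields such as seed: [1], name: [GCN_DPR], and epsilon: [0.25] are expanded by generate_experiments_from_yaml into one concrete experiment per combination. A minimal sketch of that expansion, assuming a plain cartesian product over the list-valued fields (the real generator also handles nested sections and output directories):

import itertools

def expand_template(template):
    # Split the template into fixed values and list-valued "grid" axes.
    grid = {k: v for k, v in template.items() if isinstance(v, list)}
    fixed = {k: v for k, v in template.items() if not isinstance(v, list)}
    # One experiment per element of the cartesian product of the axes.
    experiments = []
    for combo in itertools.product(*grid.values()):
        exp = dict(fixed)
        exp.update(dict(zip(grid.keys(), combo)))
        experiments.append(exp)
    return experiments

# expand_template({'seed': [0, 1], 'device': 'cuda'}) yields two experiments.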
diff --git a/custom_components/attacks/base_attack.py b/custom_components/attacks/base_attack.py
index 8c1b204..8902222 100644
--- a/custom_components/attacks/base_attack.py
+++ b/custom_components/attacks/base_attack.py
@@ -311,41 +311,41 @@ def adj_adversary_for_poisoning(self):
 # from torch_geometric.datasets import Planetoid
 # from torch_geometric.transforms import ToSparseTensor
 # from torch_geometric.utils import dense_to_sparse
-# class DenseAttack(BaseAttack):
-#     @typechecked
-#     def __init__(
-#         self,
-#         adj: Union[SparseTensor, TensorType],
-#         attr: TensorType,
-#         labels: TensorType,
-#         idx_attack: np.ndarray,
-#         model,
-#         device: Union[str, int, torch.device],
-#         make_undirected: bool = True,
-#         loss_type: str = "CE",
-#         **kwargs,
-#     ):
-#         if isinstance(adj, SparseTensor):
-#             adj = adj.to_dense()
-#         # adj = dense_to_sparse(adj)
-#         # cora = Planetoid(root='datasets', name='Cora',transform=ToSparseTensor(remove_edge_index=False))
-#         # data = cora[0]
-#         # row, col, edge_attr = data.adj_t.t().coo()
-#         # edge_index = torch.stack([row, col], dim=0)
-#         # adj = edge_index
-#         # adj = data.adj_t.to_dense()
-#         # ad
-
-#         super().__init__(
-#             adj,
-#             attr,
-#             labels,
-#             idx_attack,
-#             model,
-#             device,
-#             loss_type=loss_type,
-#             make_undirected=make_undirected,
-#             **kwargs,
-#         )
-
-#         self.n = adj.shape[0]
+class DenseAttack(BaseAttack):
+    @typechecked
+    def __init__(
+        self,
+        adj: Union[SparseTensor, TensorType],
+        attr: TensorType,
+        labels: TensorType,
+        idx_attack: np.ndarray,
+        model,
+        device: Union[str, int, torch.device],
+        make_undirected: bool = True,
+        loss_type: str = "CE",
+        **kwargs,
+    ):
+        if isinstance(adj, SparseTensor):
+            adj = adj.to_dense()
+        # adj = dense_to_sparse(adj)
+        # cora = Planetoid(root='datasets', name='Cora',transform=ToSparseTensor(remove_edge_index=False))
+        # data = cora[0]
+        # row, col, edge_attr = data.adj_t.t().coo()
+        # edge_index = torch.stack([row, col], dim=0)
+        # adj = edge_index
+        # adj = data.adj_t.to_dense()
+        # ad
+
+        super().__init__(
+            adj,
+            attr,
+            labels,
+            idx_attack,
+            model,
+            device,
+            loss_type=loss_type,
+            make_undirected=make_undirected,
+            **kwargs,
+        )
+
+        self.n = adj.shape[0]
+# """ + +# def __init__(self, epochs: int = 500, **kwargs): +# super().__init__(**kwargs) + +# rows, cols, self.edge_weight = self.adj.coo() +# self.edge_index = torch.stack([rows, cols], dim=0) + +# self.edge_index = self.edge_index.to(self.device) +# self.edge_weight = self.edge_weight.float().to(self.device) +# self.attr = self.attr.to(self.device) +# self.epochs = epochs + +# self.n_perturbations = 0 + +# def _greedy_update(self, step_size: int, gradient: torch.Tensor): +# _, topk_edge_index = torch.topk(gradient, step_size) + +# add_edge_index = self.modified_edge_index[:, topk_edge_index] +# add_edge_weight = torch.ones_like(add_edge_index[0], dtype=torch.float32) + +# if self.make_undirected: +# add_edge_index, add_edge_weight = utils.to_symmetric(add_edge_index, add_edge_weight, self.n) +# add_edge_index = torch.cat((self.edge_index, add_edge_index.to(self.device)), dim=-1) +# add_edge_weight = torch.cat((self.edge_weight, add_edge_weight.to(self.device))) +# edge_index, edge_weight = torch_sparse.coalesce( +# add_edge_index, add_edge_weight, m=self.n, n=self.n, op='sum' +# ) + +# is_one_mask = torch.isclose(edge_weight, torch.tensor(1.)) +# self.edge_index = edge_index[:, is_one_mask] +# self.edge_weight = edge_weight[is_one_mask] +# # self.edge_weight = torch.ones_like(self.edge_weight) +# assert self.edge_index.size(1) == self.edge_weight.size(0) + +# def attack(self, n_perturbations: int): +# """Perform attack + +# Parameters +# ---------- +# n_perturbations : int +# Number of edges to be perturbed (assuming an undirected graph) +# """ +# assert n_perturbations > self.n_perturbations, ( +# f'Number of perturbations must be bigger as this attack is greedy (current {n_perturbations}, ' +# f'previous {self.n_perturbations})' +# ) +# n_perturbations -= self.n_perturbations +# self.n_perturbations += n_perturbations + +# # To assert the number of perturbations later on +# clean_edges = self.edge_index.shape[1] + +# # Determine the number of edges to be flipped in each attach step / epoch +# step_size = n_perturbations // self.epochs +# if step_size > 0: +# steps = self.epochs * [step_size] +# for i in range(n_perturbations % self.epochs): +# steps[i] += 1 +# else: +# steps = [1] * n_perturbations + +# for step_size in tqdm(steps): +# # Sample initial search space (Algorithm 2, line 3-4) +# self.sample_random_block(step_size) +# # Retreive sparse perturbed adjacency matrix `A \oplus p_{t-1}` (Algorithm 2, line 7) +# edge_index, edge_weight = self.get_modified_adj() + +# if torch.cuda.is_available() and self.do_synchronize: +# torch.cuda.empty_cache() +# torch.cuda.synchronize() + +# # Calculate logits for each node (Algorithm 2, line 7) +# logits = self._get_logits(self.attr, edge_index, edge_weight) +# # Calculate loss combining all each node (Algorithm 2, line 8) +# loss = self.calculate_loss(logits[self.idx_attack], self.labels[self.idx_attack]) +# # Retreive gradient towards the current block (Algorithm 2, line 8) +# gradient = utils.grad_with_checkpoint(loss, self.perturbed_edge_weight)[0] + +# if torch.cuda.is_available() and self.do_synchronize: +# torch.cuda.empty_cache() +# torch.cuda.synchronize() + +# with torch.no_grad(): +# # Greedy update of edges (Algorithm 2, line 8) +# self._greedy_update(step_size, gradient) + +# del logits +# del loss +# del gradient + +# allowed_perturbations = 2 * n_perturbations if self.make_undirected else n_perturbations +# edges_after_attack = self.edge_index.shape[1] +# assert (edges_after_attack >= clean_edges - allowed_perturbations +# 
and edges_after_attack <= clean_edges + allowed_perturbations), \ +# f'{edges_after_attack} out of range with {clean_edges} clean edges and {n_perturbations} pertutbations' + +# self.adj_adversary = SparseTensor.from_edge_index( +# self.edge_index, self.edge_weight, (self.n, self.n) +# ).coalesce().detach() + +# self.attr_adversary = self.attr + +# @register_global_attack("GRBCD") +class GRBCDAttack(PRBCDAttack): + r"""The Greedy Randomized Block Coordinate Descent (GRBCD) adversarial + attack from the `Robustness of Graph Neural Networks at Scale + `_ paper. + + GRBCD shares most of the properties and requirements with + :class:`PRBCDAttack`. It also uses an efficient gradient based approach. + However, it greedily flips edges based on the gradient towards the + adjacency matrix. + + Args: + model (torch.nn.Module): The GNN module to assess. + block_size (int): Number of randomly selected elements in the + adjacency matrix to consider. + epochs (int, optional): Number of epochs (aborts early if + :obj:`mode='greedy'` and budget is satisfied) (default: :obj:`125`) + loss (str or callable, optional): A loss to quantify the "strength" of + an attack. Note that this function must match the output format of + :attr:`model`. By default, it is assumed that the task is + classification and that the model returns raw predictions (*i.e.*, + no output activation) or uses :obj:`logsoftmax`. Moreover, and the + number of predictions should match the number of labels passed to + :attr:`attack`. Either pass Callable or one of: :obj:`'masked'`, + :obj:`'margin'`, :obj:`'prob_margin'`, :obj:`'tanh_margin'`. + (default: :obj:`'masked'`) + is_undirected (bool, optional): If :obj:`True` the graph is + assumed to be undirected. (default: :obj:`True`) + log (bool, optional): If set to :obj:`False`, will not log any learning + progress. 
(default: :obj:`True`) """ + coeffs = {'max_trials_sampling': 20, 'eps': 1e-7} + + def __init__( + self, + block_size: int = 200_000, + epochs: int = 125, + loss: Optional[Union[str, LOSS_TYPE]] = 'masked', + make_undirected: bool = True, + log: bool = True, + **kwargs, + ): + super().__init__(block_size, epochs, loss_type=loss, + make_undirected=make_undirected, log=log, **kwargs) + + @torch.no_grad() + def _prepare(self, budget: int) -> List[int]: + """Prepare attack.""" + self.flipped_edges = self.edge_index.new_empty(2, 0).to(self.device) + + # Determine the number of edges to be flipped in each attach step/epoch + step_size = budget // self.epochs + if step_size > 0: + steps = self.epochs * [step_size] + for i in range(budget % self.epochs): + steps[i] += 1 + else: + steps = [1] * budget - def __init__(self, epochs: int = 500, **kwargs): - super().__init__(**kwargs) - - rows, cols, self.edge_weight = self.adj.coo() - self.edge_index = torch.stack([rows, cols], dim=0) - - self.edge_index = self.edge_index.to(self.device) - self.edge_weight = self.edge_weight.float().to(self.device) - self.attr = self.attr.to(self.device) - self.epochs = epochs + # Sample initial search space (Algorithm 2, line 3-4) + self._sample_random_block(step_size) - self.n_perturbations = 0 + return steps - def _greedy_update(self, step_size: int, gradient: torch.Tensor): + @torch.no_grad() + def _update(self, step_size: int, gradient: Tensor, *args, + **kwargs) -> Dict[str, Any]: + """Update edge weights given gradient.""" _, topk_edge_index = torch.topk(gradient, step_size) - add_edge_index = self.modified_edge_index[:, topk_edge_index] - add_edge_weight = torch.ones_like(add_edge_index[0], dtype=torch.float32) - - if self.make_undirected: - add_edge_index, add_edge_weight = utils.to_symmetric(add_edge_index, add_edge_weight, self.n) - add_edge_index = torch.cat((self.edge_index, add_edge_index.to(self.device)), dim=-1) - add_edge_weight = torch.cat((self.edge_weight, add_edge_weight.to(self.device))) - edge_index, edge_weight = torch_sparse.coalesce( - add_edge_index, add_edge_weight, m=self.n, n=self.n, op='sum' - ) + flip_edge_index = self.block_edge_index[:, topk_edge_index] + flip_edge_weight = torch.ones_like(flip_edge_index[0], + dtype=torch.float32) + + self.flipped_edges = torch.cat((self.flipped_edges, flip_edge_index), + axis=-1) + + if self.is_undirected: + flip_edge_index, flip_edge_weight = to_undirected( + flip_edge_index, flip_edge_weight, num_nodes=self.num_nodes, + reduce='mean') + edge_index = torch.cat( + (self.edge_index.to(self.device), flip_edge_index.to(self.device)), + dim=-1) + edge_weight = torch.cat((self.edge_weight.to(self.device), + flip_edge_weight.to(self.device))) + edge_index, edge_weight = coalesce(edge_index, edge_weight, + num_nodes=self.num_nodes, + reduce='sum') is_one_mask = torch.isclose(edge_weight, torch.tensor(1.)) self.edge_index = edge_index[:, is_one_mask] self.edge_weight = edge_weight[is_one_mask] - # self.edge_weight = torch.ones_like(self.edge_weight) assert self.edge_index.size(1) == self.edge_weight.size(0) - def attack(self, n_perturbations: int): - """Perform attack - - Parameters - ---------- - n_perturbations : int - Number of edges to be perturbed (assuming an undirected graph) - """ - assert n_perturbations > self.n_perturbations, ( - f'Number of perturbations must be bigger as this attack is greedy (current {n_perturbations}, ' - f'previous {self.n_perturbations})' - ) - n_perturbations -= self.n_perturbations - self.n_perturbations += n_perturbations 
- - # To assert the number of perturbations later on - clean_edges = self.edge_index.shape[1] - - # Determine the number of edges to be flipped in each attach step / epoch - step_size = n_perturbations // self.epochs - if step_size > 0: - steps = self.epochs * [step_size] - for i in range(n_perturbations % self.epochs): - steps[i] += 1 - else: - steps = [1] * n_perturbations - - for step_size in tqdm(steps): - # Sample initial search space (Algorithm 2, line 3-4) - self.sample_random_block(step_size) - # Retreive sparse perturbed adjacency matrix `A \oplus p_{t-1}` (Algorithm 2, line 7) - edge_index, edge_weight = self.get_modified_adj() - - if torch.cuda.is_available() and self.do_synchronize: - torch.cuda.empty_cache() - torch.cuda.synchronize() - - # Calculate logits for each node (Algorithm 2, line 7) - logits = self._get_logits(self.attr, edge_index, edge_weight) - # Calculate loss combining all each node (Algorithm 2, line 8) - loss = self.calculate_loss(logits[self.idx_attack], self.labels[self.idx_attack]) - # Retreive gradient towards the current block (Algorithm 2, line 8) - gradient = utils.grad_with_checkpoint(loss, self.perturbed_edge_weight)[0] - - if torch.cuda.is_available() and self.do_synchronize: - torch.cuda.empty_cache() - torch.cuda.synchronize() - - with torch.no_grad(): - # Greedy update of edges (Algorithm 2, line 8) - self._greedy_update(step_size, gradient) - - del logits - del loss - del gradient - - allowed_perturbations = 2 * n_perturbations if self.make_undirected else n_perturbations - edges_after_attack = self.edge_index.shape[1] - assert (edges_after_attack >= clean_edges - allowed_perturbations - and edges_after_attack <= clean_edges + allowed_perturbations), \ - f'{edges_after_attack} out of range with {clean_edges} clean edges and {n_perturbations} pertutbations' - - self.adj_adversary = SparseTensor.from_edge_index( - self.edge_index, self.edge_weight, (self.n, self.n) - ).coalesce().detach() - - self.attr_adversary = self.attr + # Sample initial search space (Algorithm 2, line 3-4) + self._sample_random_block(step_size) + + # Return debug information + scalars = { + 'number_positive_entries_in_gradient': (gradient > 0).sum().item() + } + return scalars + + def _close(self, *args, **kwargs) -> Tuple[Tensor, Tensor]: + """Clean up and prepare return argument.""" + return self.edge_index, self.flipped_edges \ No newline at end of file diff --git a/custom_components/attacks/global_attacks/new_prbcd.py b/custom_components/attacks/global_attacks/new_prbcd.py index 2e28d6a..d9566d5 100644 --- a/custom_components/attacks/global_attacks/new_prbcd.py +++ b/custom_components/attacks/global_attacks/new_prbcd.py @@ -656,7 +656,6 @@ def _append_statistics(self, mapping: Dict[str, Any]): def __repr__(self) -> str: return f'{self.__class__.__name__}()' - @register_global_attack("GRBCD") class GRBCDAttack(PRBCDAttack): r"""The Greedy Randomized Block Coordinate Descent (GRBCD) adversarial diff --git a/custom_components/datasets/datasets.py b/custom_components/datasets/datasets.py index 90cbf17..90c03bf 100644 --- a/custom_components/datasets/datasets.py +++ b/custom_components/datasets/datasets.py @@ -5,26 +5,40 @@ Coauthor, CitationFull, Reddit, - GNNBenchmarkDataset, + # GNNBenchmarkDataset, ) from ogb.nodeproppred import PygNodePropPredDataset from gnn_toolbox.registration_handler.register_components import register_dataset +# Registering datasets: register_dataset(name, dataset) -register_dataset('Cora', lambda root, transform=None, **kwargs: Planetoid(root, 
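A note on the coalesce/isclose pattern in _update above: it implements the edge flip A XOR p. Flipped edges are appended with weight 1 and summed with the existing edge list, so an edge present in both copies ends up with weight 2 (an existing edge being deleted) while an edge present exactly once keeps weight 1 (kept or newly inserted); only the weight-1 edges survive. A minimal, self-contained sketch of the same trick (illustration only, not part of the patch):

import torch
from torch_geometric.utils import coalesce

# Existing graph: edges (0,1) and (1,2); flip set: delete (1,2), insert (2,3).
edge_index = torch.tensor([[0, 1], [1, 2]])
edge_weight = torch.ones(2)
flip_index = torch.tensor([[1, 2], [2, 3]])
flip_weight = torch.ones(2)

ei = torch.cat((edge_index, flip_index), dim=-1)
ew = torch.cat((edge_weight, flip_weight))
ei, ew = coalesce(ei, ew, num_nodes=4, reduce='sum')

keep = torch.isclose(ew, torch.tensor(1.))
print(ei[:, keep])  # tensor([[0, 2], [1, 3]]): (1,2) dropped, (2,3) added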
diff --git a/custom_components/attacks/global_attacks/new_prbcd.py b/custom_components/attacks/global_attacks/new_prbcd.py
index 2e28d6a..d9566d5 100644
--- a/custom_components/attacks/global_attacks/new_prbcd.py
+++ b/custom_components/attacks/global_attacks/new_prbcd.py
@@ -656,7 +656,6 @@ def _append_statistics(self, mapping: Dict[str, Any]):
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'

-@register_global_attack("GRBCD")
 class GRBCDAttack(PRBCDAttack):
     r"""The Greedy Randomized Block Coordinate Descent (GRBCD) adversarial
     attack from the `Robustness of Graph Neural Networks at Scale
diff --git a/custom_components/datasets/datasets.py b/custom_components/datasets/datasets.py
index 90cbf17..90c03bf 100644
--- a/custom_components/datasets/datasets.py
+++ b/custom_components/datasets/datasets.py
@@ -5,26 +5,40 @@
     Coauthor,
     CitationFull,
     Reddit,
-    GNNBenchmarkDataset,
+    # GNNBenchmarkDataset,
 )
 from ogb.nodeproppred import PygNodePropPredDataset
 from gnn_toolbox.registration_handler.register_components import register_dataset

+# Registering datasets: register_dataset(name, dataset)
-register_dataset('Cora', lambda root, transform=None, **kwargs: Planetoid(root, name='Cora', transform=transform, **kwargs))
-register_dataset('Citeseer', lambda root: Planetoid(root, name='Citeseer'))
-register_dataset('PubMed', lambda root: Planetoid(root, name='PubMed'))
-register_dataset('CoraFull', CoraFull)
-register_dataset('CS', lambda root: Coauthor(root, name='CS'))
-register_dataset('Physics', lambda root: Coauthor(root, name='Physics'))
-register_dataset('Computers', lambda root: Amazon(root, name='Computers'))
-register_dataset('Photo', lambda root: Amazon(root, name='Photo'))
-register_dataset('CoraCitationFull', lambda root: CitationFull(root, name='Cora'))
-register_dataset('DBLP', lambda root: CitationFull(root, name='Cora'))
-register_dataset('PubMedCitationFull', lambda root: CitationFull(root, name='PubMed'))
-register_dataset('Reddit', Reddit)
-register_dataset('GNNBenchmarkDataset', GNNBenchmarkDataset)
-register_dataset('ogbn-arxiv', PygNodePropPredDataset)
+register_dataset('Cora', lambda root, transform, **kwargs: Planetoid(root, name='Cora', transform=transform, **kwargs))
+
+register_dataset('Citeseer', lambda root, transform, **kwargs: Planetoid(root, name='Citeseer', transform=transform, **kwargs))
+
+register_dataset('PubMed', lambda root, transform, **kwargs: Planetoid(root, name='PubMed', transform=transform, **kwargs))
+
+register_dataset('CoraFull', lambda root, transform, **kwargs: CoraFull(root, transform=transform, **kwargs))
+
+register_dataset('CS', lambda root, transform, **kwargs: Coauthor(root, name='CS', transform=transform, **kwargs))
+
+register_dataset('Physics', lambda root, transform, **kwargs: Coauthor(root, name='Physics', transform=transform, **kwargs))
+
+register_dataset('Computers', lambda root, transform, **kwargs: Amazon(root, name='Computers', transform=transform, **kwargs))
+
+register_dataset('Photo', lambda root, transform, **kwargs: Amazon(root, name='Photo', transform=transform, **kwargs))
+
+register_dataset('CoraCitationFull', lambda root, transform, **kwargs: CitationFull(root, name='Cora', transform=transform, **kwargs))
+
+register_dataset('DBLP', lambda root, transform, **kwargs: CitationFull(root, name='DBLP', transform=transform, **kwargs))
+
+register_dataset('PubMedCitationFull', lambda root, transform, **kwargs: CitationFull(root, name='PubMed', transform=transform, **kwargs))
+
+register_dataset('Reddit', lambda root, transform, **kwargs: Reddit(root, transform=transform, **kwargs))
+
+# register_dataset('GNNBenchmarkDataset', lambda root, transform, **kwargs: GNNBenchmarkDataset(root, transform=transform, **kwargs))
+
+register_dataset('PygNodePropPredDataset', lambda root, transform, **kwargs: PygNodePropPredDataset(name='ogbn-arxiv', root=root, transform=transform, **kwargs))
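A note on the registrations above: every dataset now sits behind the same (root, transform, **kwargs) factory signature, so the toolbox can instantiate any registered dataset uniformly from a config entry. A sketch of the assumed call site; the names registry and load_dataset are hypothetical illustrations of how register_dataset is consumed downstream, not the toolbox's actual API:

from torch_geometric.transforms import ToUndirected

# Hypothetical stand-in for the toolbox's registration handler.
registry = {}

def register_dataset(name, factory):
    registry[name] = factory

def load_dataset(name, root, transform=None, **params):
    # Every factory accepts (root, transform, **kwargs) by construction,
    # so one call shape covers Planetoid, Coauthor, Amazon, OGB, etc.
    return registry[name](root, transform, **params)

# dataset = load_dataset('Cora', root='./datasets', transform=ToUndirected())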
diff --git a/custom_components/models/AirGNN.py b/custom_components/models/AirGNN.py
new file mode 100644
index 0000000..7526bb1
--- /dev/null
+++ b/custom_components/models/AirGNN.py
@@ -0,0 +1,187 @@
+import logging
+import torch
+import torch.nn.functional as F
+from torch.nn import Linear
+from torch_geometric.nn.conv.gcn_conv import gcn_norm
+from torch_geometric.nn.conv import MessagePassing
+from typing import Optional, Tuple
+from torch_geometric.typing import Adj, OptTensor
+from torch import Tensor
+from torch_sparse import SparseTensor, matmul
+import torch.nn as nn
+from custom_components.utils import ensure_contiguousness
+from gnn_toolbox.registration_handler.register_components import register_model
+
+@register_model("AirGNN")
+class AirGNN(nn.Module):
+
+    def __init__(self, in_channels, hidden_channels, out_channels, nlayers=2, K=2, dropout=0.5, lr=0.01,
+                 with_bn=False, weight_decay=5e-4, model='AirGNN', lambda_amp=0, alpha=0.1):
+
+        super(AirGNN, self).__init__()
+
+        self.lins = nn.ModuleList([])
+        self.lins.append(Linear(in_channels, hidden_channels))
+        if with_bn:
+            self.bns = nn.ModuleList([])
+            self.bns.append(nn.BatchNorm1d(hidden_channels))
+        for i in range(nlayers-2):
+            self.lins.append(Linear(hidden_channels, hidden_channels))
+            if with_bn:
+                self.bns.append(nn.BatchNorm1d(hidden_channels))
+        self.lins.append(Linear(hidden_channels, out_channels))
+
+        self.prop = AdaptiveMessagePassing(K=K, alpha=alpha, mode=model, lambda_amp=lambda_amp)
+        logging.info(self.prop)
+
+        self.dropout = dropout
+        self.weight_decay = weight_decay
+        self.lr = lr
+        self.name = model
+        self.with_bn = with_bn
+
+    def initialize(self):
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for lin in self.lins:
+            lin.reset_parameters()
+        if self.with_bn:
+            for bn in self.bns:
+                bn.reset_parameters()
+        self.prop.reset_parameters()
+
+    def forward(self, x, edge_index, edge_weight=None):
+        x, edge_index, edge_weight = ensure_contiguousness(x, edge_index, edge_weight)
+        edge_index = SparseTensor.from_edge_index(edge_index, edge_weight,
+                                                  sparse_sizes=2 * x.shape[:1]).t()
+        for ii, lin in enumerate(self.lins[:-1]):
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = lin(x)
+            if self.with_bn:
+                x = self.bns[ii](x)
+            x = F.relu(x)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x = self.lins[-1](x)
+        x = self.prop(x, edge_index)
+        return F.log_softmax(x, dim=1)
+
+    def get_embed(self, x, edge_index, edge_weight=None):
+        x, edge_index, edge_weight = ensure_contiguousness(x, edge_index, edge_weight)
+        edge_index = SparseTensor.from_edge_index(edge_index, edge_weight,
+                                                  sparse_sizes=2 * x.shape[:1]).t()
+        for ii, lin in enumerate(self.lins[:-1]):
+            x = lin(x)
+            if self.with_bn:
+                x = self.bns[ii](x)
+            x = F.relu(x)
+        x = self.prop(x, edge_index)
+        return x
+
+
+class AdaptiveMessagePassing(MessagePassing):
+    _cached_edge_index: Optional[Tuple[Tensor, Tensor]]
+    _cached_adj_t: Optional[SparseTensor]
+
+    def __init__(self,
+                 K: int,
+                 alpha: float,
+                 dropout: float = 0.,
+                 cached: bool = False,
+                 add_self_loops: bool = True,
+                 normalize: bool = False,
+                 mode: str = None,
+                 node_num: int = None,
+                 lambda_amp: float = 0.5,
+                 **kwargs):
+
+        super(AdaptiveMessagePassing, self).__init__(aggr='add', **kwargs)
+        self.K = K
+        self.alpha = alpha
+        self.mode = mode
+        self.dropout = dropout
+        self.cached = cached
+        self.add_self_loops = add_self_loops
+        self.normalize = normalize
+        self._cached_edge_index = None
+        self.node_num = node_num
+        self.lambda_amp = lambda_amp
+        self._cached_adj_t = None
+
+    def reset_parameters(self):
+        self._cached_edge_index = None
+        self._cached_adj_t = None
+
+    def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None, mode=None) -> Tensor:
+        if self.normalize:
+            if isinstance(edge_index, Tensor):
+                raise ValueError('Only SparseTensor is supported for now')
+
+            elif isinstance(edge_index, SparseTensor):
+                cache = self._cached_adj_t
+                if cache is None:
+                    edge_index = gcn_norm(  # yapf: disable
+                        edge_index, edge_weight, x.size(self.node_dim), False,
+                        add_self_loops=self.add_self_loops, dtype=x.dtype)
+                    if self.cached:
+                        self._cached_adj_t = edge_index
+                else:
+                    edge_index = cache
+
+        if mode is None:
+            mode = self.mode
+
+        if self.K <= 0:
+            return x
+        hh = x
+
+        if mode == 'MLP':
+            return x
+
+        elif mode == 'APPNP':
+            x = self.appnp_forward(x=x, hh=hh, edge_index=edge_index, K=self.K, alpha=self.alpha)
+
+        elif mode in ['AirGNN']:
+            x = self.amp_forward(x=x, hh=hh, edge_index=edge_index, K=self.K)
+        else:
+            raise ValueError('wrong propagate mode')
+        return x
+
+    def appnp_forward(self, x, hh, edge_index, K, alpha):
+        for k in range(K):
+            x = self.propagate(edge_index, x=x, edge_weight=None, size=None)
+            x = x * (1 - alpha)
+            x += alpha * hh
+        return x
+
+    def amp_forward(self, x, hh, K, edge_index):
+        lambda_amp = self.lambda_amp
+        gamma = 1 / (2 * (1 - lambda_amp))  ## or simply gamma = 1
+
+        for k in range(K):
+            y = x - gamma * 2 * (1 - lambda_amp) * self.compute_LX(x=x, edge_index=edge_index)  # Equation (9)
+            x = hh + self.proximal_L21(x=y - hh, lambda_=gamma * lambda_amp)  # Equation (11) and (12)
+        return x
+
+    def proximal_L21(self, x: Tensor, lambda_):
+        row_norm = torch.norm(x, p=2, dim=1)
+        score = torch.clamp(row_norm - lambda_, min=0)
+        index = torch.where(row_norm > 0)  # Deal with the case when the row_norm is 0
+        score[index] = score[index] / row_norm[index]  # score is the adaptive score in Equation (14)
+        return score.unsqueeze(1) * x
+
+    def compute_LX(self, x, edge_index, edge_weight=None):
+        x = x - self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None)
+        return x
+
+    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
+        return edge_weight.view(-1, 1) * x_j
+
+    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
+        return matmul(adj_t, x, reduce=self.aggr)
+
+    def __repr__(self):
+        return '{}(K={}, alpha={}, mode={}, dropout={}, lambda_amp={})'.format(self.__class__.__name__, self.K,
+                                                                               self.alpha, self.mode, self.dropout,
+                                                                               self.lambda_amp)
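A note on proximal_L21 above: it is row-wise soft-thresholding in the l2 norm. Each row of x is shrunk toward zero by lambda_ and zeroed entirely if its norm falls below lambda_, which is what suppresses abnormal (e.g. adversarially perturbed) node features in AirGNN's propagation. A tiny numeric check of that operator, using the same steps as the method (illustration only, on a detached tensor):

import torch

x = torch.tensor([[3.0, 4.0],   # row norm 5
                  [0.3, 0.4]])  # row norm 0.5
lambda_ = 1.0
row_norm = torch.norm(x, p=2, dim=1)
score = torch.clamp(row_norm - lambda_, min=0)
idx = torch.where(row_norm > 0)
score[idx] = score[idx] / row_norm[idx]
out = score.unsqueeze(1) * x
# Row 0 is scaled by (5-1)/5 = 0.8 -> [2.4, 3.2]; row 1 is zeroed out.
print(out)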
diff --git a/custom_components/models/GAT.py b/custom_components/models/GAT.py
new file mode 100644
index 0000000..fee6f0d
--- /dev/null
+++ b/custom_components/models/GAT.py
@@ -0,0 +1,88 @@
+"""
+Adapted from DeepRobust project: https://github.com/DSE-MSU/DeepRobust/blob/master/deeprobust/graph/defense_pyg/gat.py
+"""
+
+
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+import torch
+from torch.nn.parameter import Parameter
+from torch.nn.modules.module import Module
+from torch_geometric.nn import GATConv
+# from .mygat_conv import GATConv
+# from .base_model import BaseModel
+
+from gnn_toolbox.registration_handler.register_components import register_model
+
+@register_model("GAT_DPR")
+class GAT(nn.Module):
+
+    def __init__(self, in_channels, hidden_channels, out_channels, heads=8, output_heads=1, dropout=0.5, lr=0.01,
+                 nlayers=2, with_bn=False, weight_decay=5e-4, with_bias=True):
+
+        super(GAT, self).__init__()
+
+        # assert device is not None, "Please specify 'device'!"
+
+        self.convs = nn.ModuleList([])
+        if with_bn:
+            self.bns = nn.ModuleList([])
+            self.bns.append(nn.BatchNorm1d(hidden_channels*heads))
+
+        self.convs.append(GATConv(
+            in_channels,
+            hidden_channels,
+            heads=heads,
+            dropout=dropout,
+            bias=with_bias))
+
+        for i in range(nlayers-2):
+            self.convs.append(GATConv(hidden_channels*heads,
+                                      hidden_channels, heads=heads, dropout=dropout, bias=with_bias))
+            if with_bn:
+                self.bns.append(nn.BatchNorm1d(hidden_channels*heads))
+
+        self.convs.append(GATConv(
+            hidden_channels * heads,
+            out_channels,
+            heads=output_heads,
+            concat=False,
+            dropout=dropout,
+            bias=with_bias))
+
+        self.dropout = dropout
+        self.weight_decay = weight_decay
+        self.lr = lr
+        self.output = None
+        self.best_model = None
+        self.best_output = None
+        self.name = 'GAT'
+        self.with_bn = with_bn
+
+    def forward(self, x, edge_index, edge_weight=None):
+        for ii, conv in enumerate(self.convs[:-1]):
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = conv(x, edge_index, edge_weight)
+            if self.with_bn:
+                x = self.bns[ii](x)
+            x = F.elu(x)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x = self.convs[-1](x, edge_index, edge_weight)
+        return F.log_softmax(x, dim=1)
+
+    def get_embed(self, x, edge_index, edge_weight=None):
+        for ii, conv in enumerate(self.convs[:-1]):
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = conv(x, edge_index, edge_weight)
+            if self.with_bn:
+                x = self.bns[ii](x)
+            x = F.elu(x)
+        return x
+
+    def initialize(self):
+        for conv in self.convs:
+            conv.reset_parameters()
+        if self.with_bn:
+            for bn in self.bns:
+                bn.reset_parameters()
\ No newline at end of file
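A note on the layer widths above: with the default concat behaviour, a GATConv's output dimension is hidden_channels times the number of heads, which is why the hidden layers take hidden_channels*heads inputs while the output layer uses concat=False to average its heads instead. A short shape walkthrough under the defaults (heads=8, output_heads=1; illustration only):

import torch
from torch_geometric.nn import GATConv

x = torch.randn(4, 16)                      # 4 nodes, 16 input features
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])

conv1 = GATConv(16, 64, heads=8)            # concat=True by default
h = conv1(x, edge_index)
print(h.shape)                              # torch.Size([4, 512]) = 64 * 8

conv2 = GATConv(64 * 8, 7, heads=1, concat=False)  # heads averaged
print(conv2(h, edge_index).shape)           # torch.Size([4, 7])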
diff --git a/custom_components/models/GCN.py b/custom_components/models/GCN.py
new file mode 100644
index 0000000..45835f0
--- /dev/null
+++ b/custom_components/models/GCN.py
@@ -0,0 +1,88 @@
+"""
+Adapted from DeepRobust project: https://github.com/DSE-MSU/DeepRobust/blob/master/deeprobust/graph/defense_pyg/gcn.py
+"""
+
+
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+import torch
+from torch.nn.parameter import Parameter
+from torch.nn.modules.module import Module
+from torch_geometric.nn import GCNConv
+# from .base_model import BaseModel
+from torch_sparse import coalesce, SparseTensor, matmul
+from custom_components.utils import ensure_contiguousness
+
+from gnn_toolbox.registration_handler.register_components import register_model
+
+@register_model("GCN_DPR")
+class GCN(nn.Module):
+
+    def __init__(self, in_channels, hidden_channels, out_channels, nlayers=2, dropout=0.5, lr=0.01,
+                 with_bn=False, weight_decay=5e-4, with_bias=True):
+
+        super(GCN, self).__init__()
+
+        self.layers = nn.ModuleList([])
+        if with_bn:
+            self.bns = nn.ModuleList()
+
+        if nlayers == 1:
+            self.layers.append(GCNConv(in_channels, out_channels, bias=with_bias))
+        else:
+            self.layers.append(GCNConv(in_channels, hidden_channels, bias=with_bias))
+            if with_bn:
+                self.bns.append(nn.BatchNorm1d(hidden_channels))
+            for i in range(nlayers-2):
+                self.layers.append(GCNConv(hidden_channels, hidden_channels, bias=with_bias))
+                if with_bn:
+                    self.bns.append(nn.BatchNorm1d(hidden_channels))
+            self.layers.append(GCNConv(hidden_channels, out_channels, bias=with_bias))
+
+        self.dropout = dropout
+        self.weight_decay = weight_decay
+        self.lr = lr
+        self.output = None
+        self.best_model = None
+        self.best_output = None
+        self.with_bn = with_bn
+        self.name = 'GCN'
+
+    def forward(self, x, edge_index, edge_weight=None):
+        x, edge_index, edge_weight = ensure_contiguousness(x, edge_index, edge_weight)
+        for ii, layer in enumerate(self.layers):
+            if edge_weight is not None:
+                adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=2 * x.shape[:1]).t()
+                x = layer(x, adj)
+            else:
+                x = layer(x, edge_index)
+            if ii != len(self.layers) - 1:
+                if self.with_bn:
+                    x = self.bns[ii](x)
+                x = F.relu(x)
+                x = F.dropout(x, p=self.dropout, training=self.training)
+        return F.log_softmax(x, dim=1)
+
+    def get_embed(self, x, edge_index, edge_weight=None):
+        x, edge_index, edge_weight = ensure_contiguousness(x, edge_index, edge_weight)
+        for ii, layer in enumerate(self.layers):
+            if ii == len(self.layers) - 1:
+                return x
+            if edge_weight is not None:
+                adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=2 * x.shape[:1]).t()
+                x = layer(x, adj)
+            else:
+                x = layer(x, edge_index)
+            if ii != len(self.layers) - 1:
+                if self.with_bn:
+                    x = self.bns[ii](x)
+                x = F.relu(x)
+        return x
+
+    def initialize(self):
+        for m in self.layers:
+            m.reset_parameters()
+        if self.with_bn:
+            for bn in self.bns:
+                bn.reset_parameters()
diff --git a/custom_components/models/GPR.py b/custom_components/models/GPR.py
new file mode 100644
index 0000000..126422e
--- /dev/null
+++ b/custom_components/models/GPR.py
@@ -0,0 +1,132 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_sparse import SparseTensor, matmul
+from torch_geometric.nn import GCNConv, SAGEConv, GATConv, APPNP, MessagePassing
+from torch_geometric.nn.conv.gcn_conv import gcn_norm
+import scipy.sparse
+import numpy as np
+from gnn_toolbox.registration_handler.register_components import register_model
+
+@register_model("GPRGNN")
+class GPRGNN(nn.Module):
+    """GPRGNN, from original repo https://github.com/jianhao2016/GPRGNN"""
+
+    def __init__(self, in_channels, hidden_channels, out_channels, Init='PPR', dprate=.5, dropout=.5,
+                 lr=0.01, weight_decay=0,
+                 K=10, alpha=.1, Gamma=None, ppnp='GPR_prop'):
+        super(GPRGNN, self).__init__()
+        self.lin1 = nn.Linear(in_channels, hidden_channels)
+        self.lin2 = nn.Linear(hidden_channels, out_channels)
+
+        if ppnp == 'PPNP':
+            self.prop1 = APPNP(K, alpha)
+        elif ppnp == 'GPR_prop':
+            self.prop1 = GPR_prop(K, alpha, Init, Gamma)
+
+        self.Init = Init
+        self.dprate = dprate
+        self.dropout = dropout
+        self.name = "GPR"
+        self.weight_decay = weight_decay
+        self.lr = lr
+
+    def initialize(self):
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.lin1.reset_parameters()
+        self.lin2.reset_parameters()
+        self.prop1.reset_parameters()
+
+    def forward(self, x, edge_index, edge_weight=None):
+
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x = F.relu(self.lin1(x))
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x = self.lin2(x)
+
+        if edge_weight is not None:
+            adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=2 * x.shape[:1])
+            if self.dprate == 0.0:
+                x = self.prop1(x, adj)
+            else:
+                x = F.dropout(x, p=self.dprate, training=self.training)
+                x = self.prop1(x, adj)
+        else:
+            if self.dprate == 0.0:
+                x = self.prop1(x, edge_index, edge_weight)
+            else:
+                x = F.dropout(x, p=self.dprate, training=self.training)
+                x = self.prop1(x, edge_index, edge_weight)
+
+        return F.log_softmax(x, dim=1)
+
+
+class GPR_prop(MessagePassing):
+    '''
+    GPRGNN, from original repo https://github.com/jianhao2016/GPRGNN
+    propagation class for GPR_GNN
+    '''
+
+    def __init__(self, K, alpha, Init, Gamma=None, bias=True, **kwargs):
+        super(GPR_prop, self).__init__(aggr='add', **kwargs)
+        self.K = K
+        self.Init = Init
+        self.alpha = alpha
+
+        assert Init in ['SGC', 'PPR', 'NPPR', 'Random', 'WS']
+        if Init == 'SGC':
+            # SGC-like
+            TEMP = 0.0*np.ones(K+1)
+            TEMP[alpha] = 1.0
+        elif Init == 'PPR':
+            # PPR-like
+            TEMP = alpha*(1-alpha)**np.arange(K+1)
+            TEMP[-1] = (1-alpha)**K
+        elif Init == 'NPPR':
+            # Negative PPR
+            TEMP = (alpha)**np.arange(K+1)
+            TEMP = TEMP/np.sum(np.abs(TEMP))
+        elif Init == 'Random':
+            # Random
+            bound = np.sqrt(3/(K+1))
+            TEMP = np.random.uniform(-bound, bound, K+1)
+            TEMP = TEMP/np.sum(np.abs(TEMP))
+        elif Init == 'WS':
+            # Specify Gamma
+            TEMP = Gamma
+
+        self.temp = nn.Parameter(torch.tensor(TEMP))
+
+    def reset_parameters(self):
+        nn.init.zeros_(self.temp)
+        for k in range(self.K+1):
+            self.temp.data[k] = self.alpha*(1-self.alpha)**k
+        self.temp.data[-1] = (1-self.alpha)**self.K
+
+    def forward(self, x, edge_index, edge_weight=None):
+        if isinstance(edge_index, torch.Tensor):
+            edge_index, norm = gcn_norm(
+                edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype)
+        elif isinstance(edge_index, SparseTensor):
+            edge_index = gcn_norm(
+                edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype)
+            norm = None
+
+        hidden = x*(self.temp[0])
+        for k in range(self.K):
+            x = self.propagate(edge_index, x=x, norm=norm)
+            gamma = self.temp[k+1]
+            hidden = hidden + gamma*x
+        return hidden
+
+    def message(self, x_j, norm):
+        return norm.view(-1, 1) * x_j
+
+    def message_and_aggregate(self, adj_t, x):
+        return matmul(adj_t, x, reduce=self.aggr)
+
+    def __repr__(self):
+        return '{}(K={}, temp={})'.format(self.__class__.__name__, self.K,
+                                          self.temp)
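A note on the Init='PPR' branch above: the propagation weights follow personalized PageRank, gamma_k = alpha*(1-alpha)^k, with the residual mass (1-alpha)^K assigned to the last hop, so the K+1 coefficients sum to one. A quick check mirroring the initialization code (illustration only):

import numpy as np

K, alpha = 10, 0.1
TEMP = alpha * (1 - alpha) ** np.arange(K + 1)
TEMP[-1] = (1 - alpha) ** K
print(TEMP[:3])    # [0.1, 0.09, 0.081]
print(TEMP.sum())  # 1.0 (up to floating point)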
diff --git a/custom_components/models/SAGE.py b/custom_components/models/SAGE.py
new file mode 100644
index 0000000..b0ae90a
--- /dev/null
+++ b/custom_components/models/SAGE.py
@@ -0,0 +1,157 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_sparse import SparseTensor, matmul
+# from torch_geometric.nn import SAGEConv, GATConv, APPNP, MessagePassing
+from torch_geometric.nn.conv.gcn_conv import gcn_norm
+import scipy.sparse
+import numpy as np
+# from .base_model import BaseModel
+from gnn_toolbox.registration_handler.register_components import register_model
+
+@register_model("SAGE")
+class SAGE(nn.Module):
+
+    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2,
+                 dropout=0.5, lr=0.01, weight_decay=0, with_bn=False, **kwargs):
+        super(SAGE, self).__init__()
+
+        self.convs = nn.ModuleList()
+        self.convs.append(
+            SAGEConv(in_channels, hidden_channels))
+
+        self.bns = nn.ModuleList()
+        if 'nlayers' in kwargs:
+            num_layers = kwargs['nlayers']
+        self.bns.append(nn.BatchNorm1d(hidden_channels))
+        for _ in range(num_layers - 2):
+            self.convs.append(
+                SAGEConv(hidden_channels, hidden_channels))
+            self.bns.append(nn.BatchNorm1d(hidden_channels))
+
+        self.convs.append(
+            SAGEConv(hidden_channels, out_channels))
+
+        self.weight_decay = weight_decay
+        self.lr = lr
+        self.dropout = dropout
+        self.activation = F.relu
+        self.with_bn = with_bn
+        self.name = "SAGE"
+
+    def initialize(self):
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for conv in self.convs:
+            conv.reset_parameters()
+        for bn in self.bns:
+            bn.reset_parameters()
+
+
+    def forward(self, x, edge_index, edge_weight=None):
+        if edge_weight is not None:
+            adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=2 * x.shape[:1]).t()
+
+        for i, conv in enumerate(self.convs[:-1]):
+            if edge_weight is not None:
+                x = conv(x, adj)
+            else:
+                x = conv(x, edge_index, edge_weight)
+            if self.with_bn:
+                x = self.bns[i](x)
+            x = self.activation(x)
+            x = F.dropout(x, p=self.dropout, training=self.training)
+        if edge_weight is not None:
+            x = self.convs[-1](x, adj)
+        else:
+            x = self.convs[-1](x, edge_index, edge_weight)
+        return F.log_softmax(x, dim=1)
+
+
+
+from typing import Union, Tuple
+from torch_geometric.typing import OptPairTensor, Adj, Size
+
+from torch import Tensor
+from torch.nn import Linear
+import torch.nn.functional as F
+from torch_sparse import SparseTensor, matmul
+from torch_geometric.nn.conv import MessagePassing
+
+
+class SAGEConv(MessagePassing):
+    r"""The GraphSAGE operator from the `"Inductive Representation Learning on
+    Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper
+
+    .. math::
+        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W_2} \cdot
+        \mathrm{mean}_{j \in \mathcal{N(i)}} \mathbf{x}_j
+
+    Args:
+        in_channels (int or tuple): Size of each input sample. A tuple
+            corresponds to the sizes of source and target dimensionalities.
+        out_channels (int): Size of each output sample.
+        normalize (bool, optional): If set to :obj:`True`, output features
+            will be :math:`\ell_2`-normalized, *i.e.*,
+            :math:`\frac{\mathbf{x}^{\prime}_i}
+            {\| \mathbf{x}^{\prime}_i \|_2}`.
+            (default: :obj:`False`)
+        bias (bool, optional): If set to :obj:`False`, the layer will not learn
+            an additive bias. (default: :obj:`True`)
+        **kwargs (optional): Additional arguments of
+            :class:`torch_geometric.nn.conv.MessagePassing`.
+    """
+    def __init__(self, in_channels: Union[int, Tuple[int, int]],
+                 out_channels: int, normalize: bool = False,
+                 bias: bool = True, **kwargs):  # yapf: disable
+        kwargs.setdefault('aggr', 'mean')
+        super(SAGEConv, self).__init__(**kwargs)
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.normalize = normalize
+
+        if isinstance(in_channels, int):
+            in_channels = (in_channels, in_channels)
+
+        self.lin_l = Linear(in_channels[0], out_channels, bias=bias)
+        self.lin_r = Linear(in_channels[1], out_channels, bias=False)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.lin_l.reset_parameters()
+        self.lin_r.reset_parameters()
+
+    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
+                size: Size = None) -> Tensor:
+        """"""
+        if isinstance(x, Tensor):
+            x: OptPairTensor = (x, x)
+
+        # propagate_type: (x: OptPairTensor)
+        out = self.propagate(edge_index, x=x, size=size)
+        out = self.lin_l(out)
+
+        x_r = x[1]
+        if x_r is not None:
+            out += self.lin_r(x_r)
+
+        if self.normalize:
+            out = F.normalize(out, p=2., dim=-1)
+
+        return out
+
+    def message(self, x_j: Tensor) -> Tensor:
+        return x_j
+
+    def message_and_aggregate(self, adj_t: SparseTensor,
+                              x: OptPairTensor) -> Tensor:
+        # Deleted the following line to make propagation differentiable
+        # adj_t = adj_t.set_value(None, layout=None)
+        return matmul(adj_t, x[0], reduce=self.aggr)
+
+    def __repr__(self):
+        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
+                                   self.out_channels)
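A note on the comment in message_and_aggregate above, which is the crux of this SAGEConv variant: stock PyG drops the adjacency values (adj_t.set_value(None, layout=None)) before aggregating, which severs the gradient path from the output back to perturbed edge weights; keeping the values and calling matmul(adj_t, x[0]) lets gradient-based attacks differentiate through the aggregation. A minimal check of that gradient path, assuming torch_sparse is installed (illustration only):

import torch
from torch_sparse import SparseTensor, matmul

edge_weight = torch.tensor([1.0, 1.0], requires_grad=True)
adj_t = SparseTensor(row=torch.tensor([0, 1]),
                     col=torch.tensor([1, 0]),
                     value=edge_weight, sparse_sizes=(2, 2))
x = torch.randn(2, 4)

out = matmul(adj_t, x, reduce='mean')
out.sum().backward()
print(edge_weight.grad)  # non-None: gradients reach the edge weights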
diff --git a/custom_components/models/architecture.py b/custom_components/models/architecture.py
index 5f243d0..f2bfb81 100644
--- a/custom_components/models/architecture.py
+++ b/custom_components/models/architecture.py
@@ -68,7 +68,35 @@ def forward(self, x, edge_index, edge_weight=None, **kwargs):
 # layers = list(self.modules())
 # for idx, m in enumerate(self.modules()):
 #     print(idx, '->', m)
-
+
+@register_model('GAT')
+class GAT1(nn.Module):
+    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
+        super().__init__()
+        self.norm = gcn_norm
+        self.conv1 = GATConv(in_channels, hidden_channels)
+        self.conv2 = GATConv(hidden_channels, out_channels)
+
+    def reset_parameters(self):
+        self.conv1.reset_parameters()
+        self.conv2.reset_parameters()
+
+    def forward(self, x, edge_index, edge_weight=None, **kwargs):
+        # Normalize edge indices only once:
+        # if not kwargs.get('skip_norm', False):
+        #     edge_index, edge_weight = self.norm(
+        #         edge_index,
+        #         edge_weight,
+        #         num_nodes=x.size(0),
+        #         add_self_loops=True,
+        #     )
+
+        x = self.conv1(x, edge_index, edge_weight).relu()
+        x = F.dropout(x, training=self.training)
+        x = self.conv2(x, edge_index, edge_weight)
+        return x
+
+
 # model = GCN(16, 16, 16)
 # print(model.hparams)
diff --git a/custom_components/models/base_model.py b/custom_components/models/base_model.py
deleted file mode 100644
index e34330a..0000000
--- a/custom_components/models/base_model.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# class BaseModel():
-#     @ staticmethod
-#     def parse_forward_input(data: Optional[Union[Data, TensorType["n_nodes", "n_features"]]] = None,
-#                             adj: Optional[Union[SparseTensor,
-#                                                 torch.sparse.FloatTensor,
-#                                                 Tuple[TensorType[2, "nnz"], TensorType["nnz"]],
-#                                                 TensorType["n_nodes", "n_nodes"]]] = None,
-#                             attr_idx: Optional[TensorType["n_nodes", "n_features"]] = None,
-#                             edge_idx: Optional[TensorType[2, "nnz"]] = None,
-#                             edge_weight: Optional[TensorType["nnz"]] = None,
-#                             n: Optional[int] = None,
-#                             d: Optional[int] = None) -> Tuple[TensorType["n_nodes", "n_features"],
-#                                                               TensorType[2, "nnz"],
-#                                                               TensorType["nnz"]]:
-#         edge_weight = None
-#         # PyTorch Geometric support
-#         if isinstance(data, Data):
-#             x, edge_idx = data.x, data.edge_index
-#         # Randomized smoothing support
-#         elif attr_idx is not None and edge_idx is not None and n is not None and d is not None:
-#             x = coalesce(attr_idx, torch.ones_like(attr_idx[0], dtype=torch.float32), m=n, n=d)
-#             x = torch.sparse.FloatTensor(x[0], x[1], torch.Size([n, d])).to_dense()
-#             edge_idx = edge_idx
-#         # Empirical robustness support
-#         elif isinstance(adj, tuple):
-#             # Necessary since `torch.sparse.FloatTensor` eliminates the gradient...
-#             x, edge_idx, edge_weight = data, adj[0], adj[1]
-#         elif isinstance(adj, SparseTensor):
-#             x = data
-
-#             edge_idx_rows, edge_idx_cols, edge_weight = adj.coo()
-#             edge_idx = torch.stack([edge_idx_rows, edge_idx_cols], dim=0)
-#         else:
-#             if not adj.is_sparse:
-#                 adj = adj.to_sparse()
-
-#             x, edge_idx, edge_weight = data, adj._indices(), adj._values()
-
-#         if edge_weight is None:
-#             edge_weight = torch.ones_like(edge_idx[0], dtype=torch.float32)
-
-#         if edge_weight.dtype != torch.float32:
-#             edge_weight = edge_weight.float()
-
-#         return x, edge_idx, edge_weight
-
-
-
-
-# model(x, edge_idx, edge_weight)
-
-# depending on the attack,
\ No newline at end of file
diff --git a/custom_components/models/model.py b/custom_components/models/model.py
deleted file mode 100644
index e416106..0000000
--- a/custom_components/models/model.py
+++ /dev/null
@@ -1,90 +0,0 @@
-from typing import Any, Dict, Tuple
-import time
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch_geometric.nn import GCNConv
-from pytorch_lightning import LightningModule
-
-# from gnn_toolbox.old.config_def import cfg
-# from custom_components.optimizers.optimizers import register_optimizer
-from gnn_toolbox.registration_handler.registry import registry
-
-class BaseModel(LightningModule):
-    def __init__(self, config, **kwargs):
-        super().__init__()
-        self.config = config
-        self.save_hyperparameters()
-
-        # self.train_acc = Accuracy(task=cfg., num_classes=out_channels)
-        # self.val_acc = Accuracy(task='multiclass', num_classes=out_channels)
-        # self.test_acc = Accuracy(task='multiclass', num_classes=out_channels)
-
-        # self.model = registry['architecture'][cfg.model.name](**cfg.model.params)
-
-    def forward(self, *args, **kwargs):
-        raise NotImplementedError
-        # return self.model(*args, **kwargs)
-
-
-    def configure_optimizers(self):
-        # Configure optimizer using the local config
-        optimizer_name = self.config['optimizer']['name']
-        optimizer_params = self.config['optimizer']['params']
-        optimizer_class = registry['optimizer'][optimizer_name]
-        return optimizer_class(self.parameters(), **optimizer_params)
-
-    # def compute_loss(self, pred, true):
-    #     raise NotImplementedError
-
-    def _shared_step(self, batch, split: str) -> Dict:
-        batch.split = split
-        pred, true = self(batch)
-        loss, pred_score = compute_loss(pred, true)
-        step_end_time = time.time()
-        self.log(split, loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
-        return dict(loss=loss, true=true, pred_score=pred_score.detach(),
-                    step_end_time=step_end_time)
-
-    def training_step(self, batch, *args, **kwargs):
-        return self._shared_step(batch, split="train")
-
-    def validation_step(self, batch, *args, **kwargs):
-        return self._shared_step(batch, split="val")
-
-    def test_step(self, batch, *args, **kwargs):
-        return self._shared_step(batch, split="test")
-
-
-# def compute_loss(pred, true):
-#     """Compute loss and prediction score.
-
-#     Args:
-#         pred (torch.tensor): Unnormalized prediction
-#         true (torch.tensor): Grou
-
-#     Returns: Loss, normalized prediction score
-
-#     """
-#     bce_loss = torch.nn.BCEWithLogitsLoss(reduction=cfg.model.size_average)
-#     mse_loss = torch.nn.MSELoss(reduction=cfg.model.size_average)
-
-#     # default manipulation for pred and true
-#     # can be skipped if special loss computation is needed
-#     pred = pred.squeeze(-1) if pred.ndim > 1 else pred
-#     true = true.squeeze(-1) if true.ndim > 1 else true
-
-#     if cfg.model.loss_fun == 'cross_entropy':
-#         # multiclass
-#         if pred.ndim > 1 and true.ndim == 1:
-#             pred = F.log_softmax(pred, dim=-1)
-#             return F.nll_loss(pred, true), pred
-#         # binary or multilabel
-#         else:
-#             true = true.float()
-#             return bce_loss(pred, true), torch.sigmoid(pred)
-#     elif cfg.model.loss_fun == 'mse':
-#         true = true.float()
-#         return mse_loss(pred, true), pred
-#     else:
-#         raise ValueError(f"Loss function '{cfg.model.loss_fun}' not supported")
\ No newline at end of file
diff --git a/custom_components/utils.py b/custom_components/utils.py
index 7f56222..d800df4 100644
--- a/custom_components/utils.py
+++ b/custom_components/utils.py
@@ -133,3 +133,12 @@ def accuracy(logits: torch.Tensor, labels: torch.Tensor, split_idx: np.ndarray)
         the Accuracy
     """
     return (logits.argmax(1)[split_idx] == labels[split_idx]).float().mean().item()
+
+def ensure_contiguousness(x, edge_idx, edge_weight):
+    if not x.is_sparse:
+        x = x.contiguous()
+    if hasattr(edge_idx, 'contiguous'):
+        edge_idx = edge_idx.contiguous()
+    if edge_weight is not None:
+        edge_weight = edge_weight.contiguous()
+    return x, edge_idx, edge_weight
\ No newline at end of file
diff --git a/gnn_toolbox/common.py b/gnn_toolbox/common.py
index f3c800d..cb2cbe8 100644
--- a/gnn_toolbox/common.py
+++ b/gnn_toolbox/common.py
@@ -370,8 +370,9 @@ def prepare_dataset(
         edge_weight = data.edge_weight
     else:
         edge_weight = torch.ones(edge_index.shape[1])
-
+    # edge_weight = torch.ones((edge_index.shape[1], 4))
     num_edges = edge_index.size(1)
+
     is_undirected_graph = is_undirected(edge_index, edge_weight)
     if is_undirected_graph:
         if not make_undirected:
@@ -431,14 +432,19 @@ def prepare_dataset(

     split = splitter(dataset, data, labels, experiment)

-
-    experiment["model"]["params"].update(
-        {
+    if "params" in experiment["model"]:
+        experiment["model"]["params"].update(
+            {
+                "in_channels": attr.shape[1],
+                "out_channels": int(labels[~labels.isnan()].max() + 1),
+            }
+        )
+    else:
+        experiment["model"]["params"] = {
             "in_channels": attr.shape[1],
             "out_channels": int(labels[~labels.isnan()].max() + 1),
         }
-    )
-
+
     return attr, adj, labels, split, num_edges

@@ -455,9 +461,7 @@ def splitter(dataset, data, labels, experiment):

     Returns:
         A dictionary containing the indices of the train, validation, and test sets.
     """
-    # print('qwe', experiment["dataset"]["train_ratio"])
-    # print('qwe', experiment["dataset"]["val_ratio"])
-    # print('qwe', experiment["dataset"]["test_ratio"])
+
     if "train_ratio" in experiment["dataset"] and "val_ratio" in experiment["dataset"] and "test_ratio" in experiment["dataset"]:
         logging.info(f"Using the provided train, val, test ratios for splitting the graph of dataset {experiment['dataset']['name']}.")
         return get_train_val_test(
experiment and save the result, so skipping to the next experiment: {experiment}.") continue - except FileExistsError as e: + except (FileExistsError, yaml.YAMLError) as e: logging.exception(e) - logging.error(f"Failed to load YAML file at {file}") + logging.error(f"Failed to load YAML file at the location {file}") except Exception as e: logging.exception(e) logging.error(f"There was an error generating experiments or creating directories for the experiment configuration file: {file}") diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..36cb2fe --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,2 @@ +def pytest_configure(config): + config.addinivalue_line("filterwarnings", "ignore::DeprecationWarning") \ No newline at end of file diff --git a/tests/test_DICE_result_with_DPR.py b/tests/test_DICE_result_with_DPR.py new file mode 100644 index 0000000..816370f --- /dev/null +++ b/tests/test_DICE_result_with_DPR.py @@ -0,0 +1,68 @@ +import pytest +import os +import torch +import numpy as np +from gnn_toolbox.experiment_handler.exp_gen import generate_experiments_from_yaml +from gnn_toolbox.experiment_handler.exp_runner import run_experiment +from gnn_toolbox.experiment_handler.artifact_manager import ArtifactManager +from gnn_toolbox.experiment_handler.config_validator import load_and_validate_yaml + +# @pytest.fixture +# def experiment_config(tmp_path): +# return { +# 'output_dir': str(tmp_path), +# 'cache_dir': str(tmp_path / 'cache'), +# 'experiment_templates': [ +# { +# 'name': 'GCN_Cora_DICE_Evasion', +# 'seed': 0, # Use a fixed seed for reproducibility +# 'device': 'cuda' if torch.cuda.is_available() else 'cpu', +# 'model': {'name': 'GCN', 'params': {'hidden_channels': 64, 'dropout': 0.5}}, +# 'dataset': {'name': 'Cora', 'make_undirected': True}, +# 'attack': { +# 'scope': 'global', +# 'type': 'evasion', +# 'name': 'DICE', +# 'epsilon': [0.05, 0.1, 0.15, 0.2] # Test with multiple epsilon values +# }, +# 'training': {'max_epochs': 200, 'patience': 20}, +# 'optimizer': {'name': 'Adam', 'params': {'lr': 0.01}}, +# 'loss': {'name': 'CrossEntropyLoss'} +# } +# ] +# } + +@pytest.fixture +def config_path(): + return os.path.join(os.path.dirname(__file__), 'test_configs', 'DICE_evasion.yaml') + +@pytest.fixture +def deeprobust_results(): + # DeepRobust paper results for DICE on Cora dataset + return { + 0.05: 0.81, # GRT result: 0.779 + 0.1: 0.80, # GRT result: 0.765 + 0.15: 0.78, # GRT result: 0.744 + 0.2: 0.75, # GRT result: 0.708 + 0.25: 0.73 # GRT result: 0.689 + } + +def test_experiment_results_against_deeprobust(config_path, deeprobust_results): + expected_clean_accuracy = 0.82 # Unperturbed clean model accuracy from DeepRobust paper + + experiment_config = load_and_validate_yaml(config_path) + experiments, cache_dir = generate_experiments_from_yaml(experiment_config) + artifact_manager = ArtifactManager(cache_dir) + + for experiment_dir, experiment in experiments.items(): + all_result, _ = run_experiment(experiment, experiment_dir, artifact_manager) + + epsilon = experiment['attack']['epsilon'] + expected_attacked_accuracy = deeprobust_results[epsilon] + + # Assert that the accuracy of GRT is close to the expected accuracy of DeepRobust to ensure the training and attacking are implemented correctly + obtained_clean_accuracy = all_result['clean_result'][-2]['accuracy_test'] + obtained_attacked_accuracy = all_result['perturbed_result']['accuracy'] + + assert obtained_clean_accuracy == pytest.approx(expected_clean_accuracy, abs=0.05) + assert 
diff --git a/tests/test_DICE_result_with_DPR.py b/tests/test_DICE_result_with_DPR.py
new file mode 100644
index 0000000..816370f
--- /dev/null
+++ b/tests/test_DICE_result_with_DPR.py
@@ -0,0 +1,68 @@
+import pytest
+import os
+import torch
+import numpy as np
+from gnn_toolbox.experiment_handler.exp_gen import generate_experiments_from_yaml
+from gnn_toolbox.experiment_handler.exp_runner import run_experiment
+from gnn_toolbox.experiment_handler.artifact_manager import ArtifactManager
+from gnn_toolbox.experiment_handler.config_validator import load_and_validate_yaml
+
+# @pytest.fixture
+# def experiment_config(tmp_path):
+#     return {
+#         'output_dir': str(tmp_path),
+#         'cache_dir': str(tmp_path / 'cache'),
+#         'experiment_templates': [
+#             {
+#                 'name': 'GCN_Cora_DICE_Evasion',
+#                 'seed': 0,  # Use a fixed seed for reproducibility
+#                 'device': 'cuda' if torch.cuda.is_available() else 'cpu',
+#                 'model': {'name': 'GCN', 'params': {'hidden_channels': 64, 'dropout': 0.5}},
+#                 'dataset': {'name': 'Cora', 'make_undirected': True},
+#                 'attack': {
+#                     'scope': 'global',
+#                     'type': 'evasion',
+#                     'name': 'DICE',
+#                     'epsilon': [0.05, 0.1, 0.15, 0.2]  # Test with multiple epsilon values
+#                 },
+#                 'training': {'max_epochs': 200, 'patience': 20},
+#                 'optimizer': {'name': 'Adam', 'params': {'lr': 0.01}},
+#                 'loss': {'name': 'CrossEntropyLoss'}
+#             }
+#         ]
+#     }
+
+@pytest.fixture
+def config_path():
+    return os.path.join(os.path.dirname(__file__), 'test_configs', 'DICE_evasion.yaml')
+
+@pytest.fixture
+def deeprobust_results():
+    # DeepRobust paper results for DICE on the Cora dataset
+    return {
+        0.05: 0.81,  # GRT result: 0.779
+        0.1: 0.80,   # GRT result: 0.765
+        0.15: 0.78,  # GRT result: 0.744
+        0.2: 0.75,   # GRT result: 0.708
+        0.25: 0.73   # GRT result: 0.689
+    }
+
+def test_experiment_results_against_deeprobust(config_path, deeprobust_results):
+    expected_clean_accuracy = 0.82  # Clean (unperturbed) model accuracy reported by DeepRobust
+
+    experiment_config = load_and_validate_yaml(config_path)
+    experiments, cache_dir = generate_experiments_from_yaml(experiment_config)
+    artifact_manager = ArtifactManager(cache_dir)
+
+    for experiment_dir, experiment in experiments.items():
+        all_result, _ = run_experiment(experiment, experiment_dir, artifact_manager)
+
+        epsilon = experiment['attack']['epsilon']
+        expected_attacked_accuracy = deeprobust_results[epsilon]
+
+        # Assert that GRT's accuracy is close to DeepRobust's, to check that training and attacks are implemented correctly
+        obtained_clean_accuracy = all_result['clean_result'][-2]['accuracy_test']
+        obtained_attacked_accuracy = all_result['perturbed_result']['accuracy']
+
+        assert obtained_clean_accuracy == pytest.approx(expected_clean_accuracy, abs=0.05)
+        assert obtained_attacked_accuracy == pytest.approx(expected_attacked_accuracy, abs=0.05)  # Allow a tolerance of 0.05
\ No newline at end of file
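The comparison relies on pytest.approx with an absolute tolerance: a run passes when |obtained − expected| ≤ 0.05. A tiny illustration of those semantics, using the (GRT, DeepRobust) pairs recorded in the fixture comments plus one hypothetical failing pair:

    import pytest

    # Fixture pairs fall inside the 0.05 band:
    assert 0.779 == pytest.approx(0.81, abs=0.05)   # eps = 0.05: |0.779 - 0.81| = 0.031
    assert 0.689 == pytest.approx(0.73, abs=0.05)   # eps = 0.25: |0.689 - 0.73| = 0.041

    # A gap wider than the tolerance (hypothetical values) would fail:
    assert 0.60 != pytest.approx(0.75, abs=0.05)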
diff --git a/tests/test_artifact_saving_retrieving.py b/tests/test_artifact_saving_retrieving.py
index a9b4c1b..a571a41 100644
--- a/tests/test_artifact_saving_retrieving.py
+++ b/tests/test_artifact_saving_retrieving.py
@@ -2,19 +2,15 @@ import pytest
 import os
 import shutil
-from unittest.mock import MagicMock, patch
+from unittest.mock import patch
 from gnn_toolbox.experiment_handler.artifact_manager import ArtifactManager
 import json
-import hashlib
 import torch

 @pytest.fixture
 def artifact_manager(tmp_path):
-    # cache_dir = str(tmp_path)
     cache_dir = tmp_path / "test_cache"
     cache_dir.mkdir(exist_ok=True)
-    # yield cache_dir
-    # os.makedirs(cache_dir, exist_ok=True)
     yield ArtifactManager(cache_dir)
     shutil.rmtree(cache_dir)

@@ -25,11 +21,6 @@ def __init__(self):
     def forward(self, x):
         return self.fc(x)

-# def test_1(artifact_manager):
-#     model = Model()
-#     model_suffix = artifact_manager / "GCN_Cora.pt"
-#     torch.save(model.state_dict(), model_suffix)
-
 def test_folder_exists(artifact_manager):
     assert artifact_manager.folder_exists(artifact_manager.cache_directory)
     assert not artifact_manager.folder_exists('non_existing_folder')
@@ -39,48 +30,18 @@ def test_hash_parameters(artifact_manager):
     params2 = {'b': {'d': 3, 'c': 2}, 'a': 1}  # Same as params1 but with different key order
     assert artifact_manager.hash_parameters(params1) == artifact_manager.hash_parameters(params2)

-
-# def test_save_model_unattacked(artifact_manager):
-#     model = Model()
-#     # params_with_no_attack = {'model': {'name': 'GCN'}, 'dataset': {'name': 'Cora'}}
-#     params = {'model': {'name': 'GCN'}, 'dataset': {'name': 'Cora'}, 'attack': {'name': 'DICE'}}
-#     result = {'accuracy': 0.85}
-
-#     # Test saving unattacked model
-#     artifact_manager.save_model(model, params, result, is_unattacked_model=True)
-#     # mock_torch_save.assert_called_once()
-#     params_with_no_attack = {key: value for key, value in params.items() if key != 'attack'}
-#     hash_id = artifact_manager.hash_parameters(params_with_no_attack)
-#     # a = os.path.join(artifact_manager.cache_directory, f"{hash_id}", 'GCN_Cora.pt')
-#     assert Path(artifact_manager.cache_directory / f"{hash_id}" / 'clean_result.json').exists()
-#     assert Path(artifact_manager.cache_directory / f"{hash_id}" / 'GCN_Cora.pt').exists()
-#     assert Path(artifact_manager.cache_directory / f"{hash_id}" / 'params.json').exists()
-#     with open(artifact_manager.cache_directory / f"{hash_id}" / 'params.json', 'r') as f:
-#         loaded_config = json.load(f)
-#         assert params_with_no_attack == loaded_config
-
-#     with open(artifact_manager.cache_directory / f"{hash_id}" / 'clean_result.json', 'r') as f:
-#         loaded_result = json.load(f)
-#         assert result == loaded_result
-    # Test saving attacked model
-    # artifact_manager.save_model(model, params, result, is_unattacked_model=False)
-    # hash_id = artifact_manager.hash_parameters(params)
-    # assert os.path.exists(os.path.join(artifact_manager.cache_directory, f"{hash_id}", 'GCN_Cora_DICE.pt'))
-
-
-# @patch('gnn_toolbox.experiment_handler.artifact_manager.torch.save')
 def test_save_model_unattacked(artifact_manager):
     model = Model()
-    # params_with_no_attack = {'model': {'name': 'GCN'}, 'dataset': {'name': 'Cora'}}
+
     params = {'model': {'name': 'GCN'}, 'dataset': {'name': 'Cora'}, 'attack': {'name': 'DICE'}}
     result = {'accuracy': 0.85}

     # Test saving unattacked model
     artifact_manager.save_model(model, params, result, is_unattacked_model=True)
-    # mock_torch_save.assert_called_once()
+
     params_with_no_attack = {key: value for key, value in params.items() if key != 'attack'}
     hash_id = artifact_manager.hash_parameters(params_with_no_attack)
-    # a = os.path.join(artifact_manager.cache_directory, f"{hash_id}", 'GCN_Cora.pt')
+
     assert Path(artifact_manager.cache_directory / f"{hash_id}" / 'clean_result.json').exists()
     assert Path(artifact_manager.cache_directory / f"{hash_id}" / 'GCN_Cora.pt').exists()
     assert Path(artifact_manager.cache_directory / f"{hash_id}" / 'params.json').exists()
@@ -94,13 +55,13 @@ def test_save_model_unattacked(artifact_manager):
 def test_save_model_attacked(artifact_manager):
     model = Model()
-    # params_with_no_attack = {'model': {'name': 'GCN'}, 'dataset': {'name': 'Cora'}}
+
     params = {'model': {'name': 'GCN'}, 'dataset': {'name': 'Cora'}, 'attack': {'name': 'DICE'}}
     result = {'accuracy': 0.85}

     # Test saving attacked model
     artifact_manager.save_model(model, params, result, is_unattacked_model=False)
-    # params_with_no_attack = {key: value for key, value in params.items() if key != 'attack'}
+
     hash_id = artifact_manager.hash_parameters(params)
     assert Path(artifact_manager.cache_directory / f"{hash_id}" / 'attacked_result.json').exists()
@@ -127,7 +88,7 @@ def test_save_model_file_already_exists(mock_torch_save, artifact_manager):
     # Try saving again (should overwrite or raise an error)
     mock_torch_save.reset_mock()  # Reset the mock to check for new calls
     artifact_manager.save_model(model, params, result, is_unattacked_model=True)
-    mock_torch_save.assert_called_once()  # Ensure save is called again (overwriting)
+    mock_torch_save.assert_called_once()  # Check that save is called again (overwriting)

 def test_model_exists(artifact_manager):
     model = Model()
@@ -156,9 +117,6 @@ def test_model_exists(artifact_manager):
 def test_model_exists_non_existing_directory(artifact_manager):
     params = {'model': {'name': 'GCN'}, 'dataset': {'name': 'Cora'}, 'attack': {'name': 'DICE'}}
-    # hash_id = artifact_manager.hash_parameters(params)
-    # params_dir = os.path.join(artifact_manager.cache_directory, hash_id)
-    # shutil.rmtree(params_dir)  # Delete the directory
     model_path, loaded_result = artifact_manager.model_exists(params, is_unattacked_model=True)
     assert model_path is None
@@ -181,11 +139,6 @@ def test_model_exists_partial_file_existence(artifact_manager):
     assert os.path.exists(model_path)
     assert loaded_result is None

-# def test_artifact_manager_invalid_cache_directory(tmp_path):
-#     invalid_cache_dir = tmp_path / 'nonexistent'
-#     with pytest.raises(FileNotFoundError):
-#         ArtifactManager(invalid_cache_dir)
-
 def test_artifact_manager_empty_cache_directory(artifact_manager):
-    # No issues expected with an empty cache directory
-    pass  # This test simply ensures the constructor doesn't crash
\ No newline at end of file
+    # No issues are expected with an empty cache directory
+    pass  # This test simply ensures the ArtifactManager constructor doesn't crash
\ No newline at end of file
diff --git a/tests/test_bulk_data.py b/tests/test_bulk_data.py
deleted file mode 100644
index e69de29..0000000
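test_hash_parameters pins down the key property the artifact cache relies on: hashes must be insensitive to dictionary key order, including in nested dicts. One way to obtain that property — a sketch of the idea, not necessarily ArtifactManager's actual implementation — is to hash a canonically serialized form:

    import hashlib
    import json

    def hash_parameters(params: dict) -> str:
        # sort_keys canonicalizes ordering at every nesting level, so
        # {'a': 1, 'b': {'c': 2, 'd': 3}} and its reordered twin serialize identically.
        canonical = json.dumps(params, sort_keys=True)
        return hashlib.sha256(canonical.encode('utf-8')).hexdigest()

    assert hash_parameters({'a': 1, 'b': {'c': 2, 'd': 3}}) == \
           hash_parameters({'b': {'d': 3, 'c': 2}, 'a': 1})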
diff --git a/tests/test_configs/DICE_evasion.yaml b/tests/test_configs/DICE_evasion.yaml
new file mode 100644
index 0000000..2650fcf
--- /dev/null
+++ b/tests/test_configs/DICE_evasion.yaml
@@ -0,0 +1,29 @@
+output_dir: './output'
+cache_dir: './cache'
+experiment_templates:
+  - name: 'GCN_Cora_DICE_Evasion'
+    seed: 0
+    device: 'cuda'
+    model:
+      name: 'GCN_DPR'
+      params:
+        hidden_channels: 64
+    dataset:
+      name: 'Cora'
+      make_undirected: True
+    attack:
+      scope: 'global'
+      type: 'evasion'
+      name: 'DICE'
+      epsilon: [0.05, 0.1, 0.15, 0.2, 0.25]
+    training:
+      max_epochs: 200
+      patience: 20
+    optimizer:
+      name: 'adam'
+      params:
+        lr: 0.01
+        weight_decay: 0.0005
+
+    loss:
+      name: 'CE'
\ No newline at end of file
diff --git a/tests/test_configs/invalid_config.yaml b/tests/test_configs/invalid_config.yaml
index 2e63271..d767041 100644
--- a/tests/test_configs/invalid_config.yaml
+++ b/tests/test_configs/invalid_config.yaml
@@ -1,22 +1,22 @@
 output_dir: './output'
 cache_dir: './cache'
 experiment_templates:
-  - name: 'Invalid_Experiment'
-    seed: 'not_an_integer' # Invalid data type for seed
+  - name: 'invalid'
+    seed: 'invalid' # Invalid data type for seed
     device: 'gpu' # Invalid device name
     model:
-      name: 'NonExistentModel' # Invalid model name
+      name: 'invalid' # Invalid model name
     dataset:
-      name: 'FakeDataset' # Invalid dataset name
+      name: 'invalid' # Invalid dataset name
     attack:
-      scope: 'planetwide' # Invalid attack scope
-      type: 'confuse' # Invalid attack type
-      name: 'MagicAttack' # Invalid attack name
+      scope: 'invalid' # Invalid attack scope
+      type: 'invalid' # Invalid attack type
+      name: 'invalid' # Invalid attack name
       epsilon: 2.0 # Epsilon out of range
     training:
       max_epochs: 10
       patience: 20 # Patience greater than max_epochs
     optimizer:
-      name: 'ImaginaryOptimizer' # Invalid optimizer name
+      name: 'invalid' # Invalid optimizer name
     loss:
-      name: 'MadeUpLoss' # Invalid loss name
\ No newline at end of file
+      name: 'invalid' # Invalid loss name
\ No newline at end of file
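DICE_evasion.yaml lists five epsilon values in a single template, while the DPR test above iterates over `experiments.items()` and reads `experiment['attack']['epsilon']` as a scalar; this implies `generate_experiments_from_yaml` expands list-valued fields into one concrete experiment per value. A sketch of that expansion for the epsilon field alone (the real generator handles many more list-valued fields and nesting):

    template = {'name': 'GCN_Cora_DICE_Evasion',
                'attack': {'name': 'DICE', 'epsilon': [0.05, 0.1, 0.15, 0.2, 0.25]}}

    # One concrete experiment per epsilon value, each with a scalar epsilon:
    experiments = [
        {**template, 'attack': {**template['attack'], 'epsilon': eps}}
        for eps in template['attack']['epsilon']
    ]

    assert len(experiments) == 5
    assert experiments[0]['attack']['epsilon'] == 0.05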
diff --git a/tests/test_yaml_validator.py b/tests/test_experiment_yaml_validator.py
similarity index 84%
rename from tests/test_yaml_validator.py
rename to tests/test_experiment_yaml_validator.py
index 896b3e2..df5190e 100644
--- a/tests/test_yaml_validator.py
+++ b/tests/test_experiment_yaml_validator.py
@@ -1,59 +1,3 @@
-# import pytest
-# import os
-# from gnn_toolbox.experiment_handler.config_validator import (
-#     load_and_validate_yaml,
-#     Config,
-#     ExperimentTemplate,
-#     Model,
-#     Dataset,
-#     Attack,
-#     Training,
-#     Optimizer,
-#     Loss,
-#     Transform,
-# )
-
-# @pytest.fixture
-# def valid_config_path():
-#     return os.path.join(os.path.dirname(__file__), "test_configs", "valid_config.yaml")
-
-
-# @pytest.fixture
-# def invalid_config_path():
-#     return os.path.join(
-#         os.path.dirname(__file__), "test_configs", "invalid_config.yaml"
-#     )
-
-
-# def test_load_and_validate_yaml_valid(valid_config_path):
-#     config = load_and_validate_yaml(valid_config_path)
-#     assert isinstance(config, dict)
-#     assert "output_dir" in config
-#     assert "experiment_templates" in config
-
-
-# def test_load_and_validate_yaml_invalid(invalid_config_path, caplog):
-#     with pytest.raises(ValueError) as excinfo:
-#         load_and_validate_yaml(invalid_config_path)
-#     assert "Validation error(s) encountered" in str(excinfo.value)
-#     assert len(caplog.records) > 0
-#     for record in caplog.records:
-#         assert record.levelname == "ERROR"
-
-
-# # def test_config_model():
-# #     # Valid model configuration
-# #     valid_model = Model(name="GCN", params={"hidden_channels": 64, "dropout": 0.5})
-# #     assert valid_model.name == "GCN"
-# #     assert valid_model.params == {"hidden_channels": 64, "dropout": 0.5}
-
-# #     # Invalid model name
-# #     with pytest.raises(ValueError) as excinfo:
-# #         Model(name="InvalidModelName")
-# #     assert "Invalid model name" in str(excinfo.value)
-
-
-# test_config_validator.py
 import pytest
 import os
 from gnn_toolbox.experiment_handler.config_validator import (
@@ -80,7 +24,7 @@ def valid_config_path():
 def invalid_config_path():
     return os.path.join(os.path.dirname(__file__), 'test_configs', 'invalid_config.yaml')

-@pytest.fixture
+@pytest.fixture  # Kept function-scoped: the fixture requests pytest's function-scoped 'monkeypatch'
 def mock_registry(monkeypatch):
     mock_registry_data = {
         'model': {'GCN': None, 'GCN2': None},
@@ -91,7 +35,7 @@ def mock_registry(monkeypatch):
         'global_attack': {'DICE': None, 'PRBCD': None},
         'local_attack': {'LocalDICE': None, 'LocalPRBCD': None},
     }
-    monkeypatch.setattr('gnn_toolbox.experiment_handler.config_validator.registry', mock_registry_data) # Mock the registry
+    monkeypatch.setattr('gnn_toolbox.registration_handler.registry.registry', mock_registry_data) # Mock the registry

 ##################### Experiment configuration file validation tests #####################
diff --git a/tests/test_component_registration.py b/tests/test_registration_component.py
similarity index 100%
rename from tests/test_component_registration.py
rename to tests/test_registration_component.py
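The fixture now patches the registry on the module that defines it rather than on config_validator, which imports it. Whether that takes effect depends on the import style: config_validator uses `from gnn_toolbox.registration_handler.registry import registry`, which binds its own local reference, so rebinding the attribute on the defining module is not seen by code holding the old reference. A generic, self-contained illustration of the pitfall (the module here is a synthetic stand-in, not the toolbox's):

    import types

    # Stand-in for the defining module, e.g. registration_handler.registry
    registry_mod = types.ModuleType('registry_mod')
    registry_mod.registry = {'model': {'GCN': None}}

    # "from registry_mod import registry" just copies the reference:
    registry = registry_mod.registry

    # Rebinding the defining module's attribute (what monkeypatch.setattr on
    # that dotted path does) is invisible to the already-imported name:
    registry_mod.registry = {'model': {'MOCK': None}}
    assert registry == {'model': {'GCN': None}}

Patching the name where it is used (here, `config_validator.registry`) avoids this, which is why the previous target is generally considered the safer one for `from ... import` bindings.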