From bb9165397d1c3e3f8c25571e772d0033d2c6afe4 Mon Sep 17 00:00:00 2001 From: perieger Date: Thu, 8 Feb 2024 19:15:14 +0100 Subject: [PATCH] Implement integration of CrowdGuard (#903) * Implement integration of CrowdGuard Signed-off-by: Phillip Rieger * Fix formatting Signed-off-by: Phillip Rieger * fix lint checks Signed-off-by: Phillip Rieger * outsource pretrained model Signed-off-by: Phillip Rieger * Add note that execution inside TEEs will be added in the future Signed-off-by: Phillip Rieger --------- Signed-off-by: Phillip Rieger Co-authored-by: Phillip Rieger Signed-off-by: nammbash --- .../experimental/CrowdGuard/.gitignore | 2 + .../CrowdGuard/CrowdGuardClientValidation.py | 451 ++++ .../CrowdGuard/PoisoningAttackDemo.ipynb | 1909 +++++++++++++++++ .../PoisoningAttackDemoReduced.ipynb | 1443 +++++++++++++ .../CrowdGuard/cifar10_crowdguard.py | 620 ++++++ .../experimental/CrowdGuard/readme.md | 23 + 6 files changed, 4448 insertions(+) create mode 100644 openfl-tutorials/experimental/CrowdGuard/.gitignore create mode 100644 openfl-tutorials/experimental/CrowdGuard/CrowdGuardClientValidation.py create mode 100644 openfl-tutorials/experimental/CrowdGuard/PoisoningAttackDemo.ipynb create mode 100644 openfl-tutorials/experimental/CrowdGuard/PoisoningAttackDemoReduced.ipynb create mode 100644 openfl-tutorials/experimental/CrowdGuard/cifar10_crowdguard.py create mode 100644 openfl-tutorials/experimental/CrowdGuard/readme.md diff --git a/openfl-tutorials/experimental/CrowdGuard/.gitignore b/openfl-tutorials/experimental/CrowdGuard/.gitignore new file mode 100644 index 00000000000..3b6599e04b3 --- /dev/null +++ b/openfl-tutorials/experimental/CrowdGuard/.gitignore @@ -0,0 +1,2 @@ +.metaflow/ +data diff --git a/openfl-tutorials/experimental/CrowdGuard/CrowdGuardClientValidation.py b/openfl-tutorials/experimental/CrowdGuard/CrowdGuardClientValidation.py new file mode 100644 index 00000000000..1e8d5e2c59d --- /dev/null +++ b/openfl-tutorials/experimental/CrowdGuard/CrowdGuardClientValidation.py @@ -0,0 +1,451 @@ +# Copyright (C) 2022-2024 TU Darmstadt +# SPDX-License-Identifier: Apache-2.0 + +# ----------------------------------------------------------- +# Primary author: Phillip Rieger +# Co-authored-by: Torsten Krauss +# ------------------------------------------------------------ +from enum import Enum +import math +from copy import deepcopy +import numpy as np +import torch +from matplotlib import pyplot as plt +from scipy.stats import kstest, levene, ttest_ind +from sklearn import preprocessing +from sklearn.cluster import AgglomerativeClustering +from sklearn.decomposition import PCA +from torch import cosine_similarity + + +class DistanceMetric(str, Enum): + """Enum to identify distance metrics necessary in this project""" + COSINE = 'cosine' + EUCLIDEAN = 'euclid' + + +class DistanceHandler: + """Helper, that calculates distances between two tensors.""" + + @staticmethod + def __get_euclid_distance(t1: torch.Tensor, t2: torch.Tensor) -> float: + t = t1.view(-1) - t2.view(-1) + return torch.norm(t, 2).cpu().item() + + @staticmethod + def __get_cosine_distance(t1: torch.Tensor, t2: torch.Tensor) -> float: + t1 = t1.view(-1).reshape(1, -1) + t2 = t2.view(-1).reshape(1, -1) + return 1 - cosine_similarity(t1, t2).cpu().item() + + @staticmethod + def get_distance(distance, t1: torch.Tensor, t2: torch.Tensor) -> float: + """Factory Method for Distances""" + if distance == DistanceMetric.COSINE: + return DistanceHandler.__get_cosine_distance(t1, t2) + if distance == DistanceMetric.EUCLIDEAN: + return DistanceHandler.__get_euclid_distance(t1, t2) + + raise Exception(f"Extractor for {distance} not implemented yet.") + + +class CrowdGuardClientValidation: + + @staticmethod + def __distance_global_model_final_metric(distance_type: str, prediction_matrix, + prediction_global_model, sample_indices_by_label, + own_index): + """ + Calculates the distance matrix containing the metric for CrowdGuard + with dimensions label x model x layer x values + """ + + sample_count = len(prediction_matrix) + model_count = len(prediction_matrix[0]) + layer_count = len(prediction_matrix[0][0]) + + # We create a distance matrix with distances between global and local models + # of the dimensions sample x model x layer x values + global_distance_matrix = [[[0.] * layer_count for _ in range(model_count)] + for _ in range(sample_count)] + # 1. calculate distances between predictions of global model and each local model + for s_i, s in enumerate(prediction_matrix): + g = prediction_global_model[s_i] + for m_i, m in enumerate(s): + for l_i, l in enumerate(m): + distance = DistanceHandler.get_distance(distance_type, l, g[ + l_i]) # either euclidean or cosine distance + global_distance_matrix[s_i][m_i][l_i] = distance # line 18 + + # 2. Sort the sample-wise distances by the label of the sample + for label, sample_list in sample_indices_by_label.items(): + # First pick the samples from the global predictions + global_distance_matrix_for_label_helper = [ + [[0.] * len(sample_list) for _ in range(layer_count)] for _ in + range(model_count)] + + s_i_new = 0 + for s_i, s in enumerate(global_distance_matrix): + if s_i not in sample_list: + continue + for m_i, mi in enumerate(s): + for l_i, l in enumerate(mi): + global_distance_matrix_for_label_helper[m_i][l_i][s_i_new] = l + s_i_new += 1 + + # We produce the first relative matrix + sample_relation_matrix = [[[0.] * layer_count for _ in range(model_count)] for _ in + range(sample_count)] + + # 3. divide by distances of this client to use its values as reference + for s_i, s in enumerate(global_distance_matrix): + distances_for_own_models_predictions = s[own_index] + for m_j, mj in enumerate(s): + for l_i, l in enumerate(mj): + relation = 0 + if distances_for_own_models_predictions[l_i] != 0: + relation = l / distances_for_own_models_predictions[l_i] + sample_relation_matrix[s_i][m_j][l_i] = relation # line 21 + + # We produce the Label average + # We produce a matrix with not all samples, but mean all the samples, so that we have a + # Matrix per label + sample_relation_matrix_for_label = {} + + # 4. Transpose matrix as preparation for averaging + for label, sample_list in sample_indices_by_label.items(): + sample_relation_matrix_for_label[label] = [[0.] * layer_count for _ in + range(model_count)] + sample_relation_matrix_for_label_helper = [ + [[0.] * len(sample_list) for _ in range(layer_count)] for _ in range(model_count)] + # transpose dimensions of distance matrix, before we had (sample,model, layer) and + # we transpose it to (model,layer,sample) + s_i_new = 0 + for s_i, s in enumerate(sample_relation_matrix): + if s_i not in sample_list: + continue + for m_j, mj in enumerate(s): + for l_i, l in enumerate(mj): + sample_relation_matrix_for_label_helper[m_j][l_i][s_i_new] = l + s_i_new += 1 + + # 5. Average over all samples from the same label (basically kick-out the last + # dimension) + for m_j, mj in enumerate(sample_relation_matrix_for_label_helper): + for l_i, l in enumerate(mj): + sample_relation_matrix_for_label[label][m_j][l_i] = np.mean(l).item() + + avg_sample_relation_matrix_squared_negative_models_first = {} + + # 6. subtract 1 (mainly for cosine distances) and square (but keep the sign) + for label, label_values in sample_relation_matrix_for_label.items(): + avg_sample_relation_matrix_squared_negative_models_first[label] = [[0.] * layer_count + for _ in + range(model_count)] + for m_j, mj in enumerate(label_values): + for l_i, l in enumerate(mj): + x = l - 1 + relation = x * x + relation = relation if x >= 0 else relation * (-1) + avg_sample_relation_matrix_squared_negative_models_first[label][m_j][ + l_i] = relation + return avg_sample_relation_matrix_squared_negative_models_first + + @staticmethod + def __predict_for_single_model(model, local_data, device): + """ + Returns + - A matrix with Deep Layer Outputs with dimensions sample x layer x values. + - The labels for all samples in the client's training dataset + - The number of layers defined in the model + """ + num_layers = None + sample_label_list = [] + predictions = [] + model.eval() + model = model.to(device) + number_of_previous_samples = 0 + for batch_id, batch in enumerate(local_data): + data, target = batch + data, target = data.to(device), target.to(device) + output = model.predict_internal_states(data) + if num_layers is None: + num_layers = len(output) + assert num_layers == len(output) + + for layer_output_index, layer_output_values in enumerate(output): + for idx in range(target.shape[0]): + sample_idx = number_of_previous_samples + idx + assert len(predictions) >= sample_idx + if len(predictions) == sample_idx: + assert layer_output_index == 0 + predictions.append([]) + + if layer_output_index == 0: + expected_predictions = sample_idx + 1 + else: + expected_predictions = number_of_previous_samples + target.shape[0] + assert_msg = f'{len(predictions)} vs. {sample_idx} ({idx} {batch_id} ' + assert_msg += f'{layer_output_index} {number_of_previous_samples})' + assert len(predictions) == expected_predictions, assert_msg + assert_msg = f'{len(predictions[sample_idx])} {layer_output_index} ' + assert_msg += f'{sample_idx} {batch_id} {idx} {number_of_previous_samples}' + assert len(predictions[sample_idx]) == layer_output_index, assert_msg + value = layer_output_values[idx].clone().detach().cpu() + predictions[sample_idx].append(value) + number_of_previous_samples += target.shape[0] + for t in target: + sample_label_list.append(t.detach().clone().cpu().item()) + model.cpu() + return predictions, sample_label_list, num_layers + + @staticmethod + def __do_predictions(models, global_model, local_data, device): + """ + Returns + - The Deep Layer Outputs for all models in a matrix of dimension + sample x model x layer x value + - The Deep Layer Outputs of the global model int he dimension sample x layer x value + - A dict containing lists of sample indices for each label class + - The number of layers from the model + """ + all_models_predictions = [] + for model_index, model in enumerate(models): + predictions, _, _ = CrowdGuardClientValidation.__predict_for_single_model(model, + local_data, + device) + for sample_index, layer_predictions_for_sample in enumerate(predictions): + if sample_index >= len(all_models_predictions): + assert model_index == 0 + assert len(all_models_predictions) == sample_index + all_models_predictions.append([]) + all_models_predictions[sample_index].append(layer_predictions_for_sample) + tmp = CrowdGuardClientValidation.__predict_for_single_model(global_model, local_data, + device) + global_model_predictions, sample_label_list, n_layers = tmp + sample_indices_by_label = {} + for s_i, label in enumerate(sample_label_list): + if label not in sample_indices_by_label.keys(): + sample_indices_by_label[label] = [] + sample_indices_by_label[label].append(s_i) + + return all_models_predictions, global_model_predictions, sample_indices_by_label, n_layers + + @staticmethod + def __prune_poisoned_models(num_layers, total_number_of_clients, own_client_index, + distances_by_metric, verbose=False): + detected_poisoned_models = [] + for distance_type in distances_by_metric.keys(): + + # First load the distance Matrix for this client and the samples by labels. + distance_matrix_la_m_l = distances_by_metric[distance_type] + + # We put all of our labels into one big row. + layer_length = num_layers * len(distance_matrix_la_m_l) + dist_matrix_m_lcon = [[0.] * layer_length for _ in range(total_number_of_clients)] + label_count = 0 + for label_x, dist_matrix_m_l_for_label in distance_matrix_la_m_l.items(): + for model_idx, model_values in enumerate(dist_matrix_m_l_for_label): + for layer_idx, layer in enumerate(model_values): + dist_matrix_m_lcon[model_idx][layer_idx + label_count * num_layers] = layer + label_count = label_count + 1 + + dist_matrix_m_l = dist_matrix_m_lcon + + client_indices = [x for x, value in enumerate(dist_matrix_m_l) if + x != own_client_index] + pruned_indices = [] + has_malicious_model = True + new_round_needed = True + prune_idx = 0 + + max_pruning_count = int(math.floor((len(dist_matrix_m_l) - 1) / 2)) + + while has_malicious_model and new_round_needed: + # unique + pruned_indices_local = deepcopy(pruned_indices) + # Ignore the own label again and the pruned indices + pruned_cluster_input_m_l = [value for x, value in + enumerate(dist_matrix_m_l) if + x != own_client_index and x not in pruned_indices] + pruned_client_indices = [x for x, value in + enumerate(dist_matrix_m_l) if + x != own_client_index and x not in pruned_indices] + + if len(pruned_cluster_input_m_l) <= 1: + break + + layer_values = {} + + for m in pruned_cluster_input_m_l: + for l_i, l in enumerate(m): + if l_i not in layer_values.keys(): + layer_values[l_i] = [] + layer_values[l_i].append(l) + + median_layer_values = [] + + for l_i, l_values in layer_values.items(): + median_layer_values.append(np.median(l_values).item()) + + median_graph = list(median_layer_values) + + pca_list = [] + for m in pruned_cluster_input_m_l: + pca_list.append(m) + + pca_list.append(median_graph) + + scaled_data = preprocessing.scale(pca_list) + + pca = PCA() + pca.fit(scaled_data) + pca_data = pca.transform(scaled_data) + + cluster_input = [] + cluster_input_plain = [] + pca_one_data = pca_data.T[0] + for pca_one_value in pca_one_data: + cluster_input.append([pca_one_value]) + cluster_input_plain.append(pca_one_value) + + # Significance tests + median_val = np.median(cluster_input_plain) + if verbose: + print(f'cluster_input_plain={cluster_input_plain}') + x_values = [] + y_values = [] + for value in cluster_input_plain: + # Split the samples into two groups + distance_value = abs(value - median_val) + if value >= median_val: + x_values.append(distance_value) + else: + y_values.append(distance_value) + print(f'Distance: {distance_type}, use y {len(y_values)}: {y_values}') + print(f'Distance: {distance_type}, use x {len(x_values)}: {x_values}') + + # Statistical tests + t_value, t_p_value = ttest_ind(x_values, y_values) + ks_value, ks_p_value = kstest(x_values, y_values) + barlett_value, bartlett_p_value = levene(x_values, y_values) + + # Outlier tests + # Creating boxplot + bp_result = plt.boxplot(cluster_input_plain, whis=5.5) + fliers = bp_result['fliers'][0].get_ydata() + outlier_boxplot = len(fliers) + plt.close() + + # Outlier based on variance + deviation_mean = np.mean(cluster_input_plain) + deviation_std = abs(np.std(cluster_input_plain)) + + max_dist_rule_factor = 0 + for cip in cluster_input_plain: + cip_abs = abs(cip - deviation_mean) + rule_factor = cip_abs / deviation_std + if max_dist_rule_factor < rule_factor: + max_dist_rule_factor = rule_factor + + outlier_three_sigma = max_dist_rule_factor + + has_malicious_model_t_threshold = False + if t_p_value < 0.01: + has_malicious_model_t_threshold = True + has_malicious_model_ks_threshold = False + if ks_p_value < 0.01: + has_malicious_model_ks_threshold = True + has_malicious_model_bartlett_threshold = False + if bartlett_p_value < 0.01: + has_malicious_model_bartlett_threshold = True + + has_boxplot_outlier = False + has_three_sigma_outlier = False + + if outlier_boxplot > 0: + has_boxplot_outlier = True + if outlier_three_sigma >= 3: + has_three_sigma_outlier = True + + # Choose exit criterium + has_malicious_model = (has_malicious_model_t_threshold + or has_malicious_model_ks_threshold + or has_malicious_model_bartlett_threshold + or has_boxplot_outlier + or has_three_sigma_outlier) + + ac_e = AgglomerativeClustering(n_clusters=2, distance_threshold=None, + compute_full_tree=True, + affinity="euclidean", memory=None, + connectivity=None, + linkage='single', + compute_distances=True).fit(cluster_input) + ac_e_labels: list = ac_e.labels_.tolist() + median_value_cluster_label = ac_e_labels[-1] + ac_e_malicious_class_indices = [idx for idx, val in enumerate(ac_e_labels) if + val != median_value_cluster_label] + + for m_j, value in enumerate(pruned_client_indices): + if m_j in ac_e_malicious_class_indices: + pruned_indices_local.append(value) + + pruned_indices_local = list(set(pruned_indices_local)) + + # If we now prune more than half, we stop and remove the best items from the last + # pruning list. + pruned_too_much = True + if len(pruned_indices_local) > max_pruning_count: + dist_values_of_pruned_models = [] + for midx in ac_e_malicious_class_indices: + dist_to_median = abs(cluster_input[midx][0] - cluster_input[-1][0]) + dist_values_of_pruned_models.append(dist_to_median) + + sorted_dist_values_of_pruned_models = list(dist_values_of_pruned_models) + sorted_dist_values_of_pruned_models.sort() + + sorted_ac_e_malicious_class_indices = [] + for sdv in sorted_dist_values_of_pruned_models: + dvidx = dist_values_of_pruned_models.index(sdv) + for m_j, value in enumerate(pruned_client_indices): + if m_j == ac_e_malicious_class_indices[dvidx]: + sorted_ac_e_malicious_class_indices.append(value) + overflowed_count = len(pruned_indices_local) - max_pruning_count + for oc in range(overflowed_count): + # Get the values of the clusters and remove the nearest ones + # from pruned_indices_local + pruned_indices_local.remove(sorted_ac_e_malicious_class_indices[-1]) + del sorted_ac_e_malicious_class_indices[-1] + pruned_too_much = False + + still_pruning = len(pruned_indices) < len(pruned_indices_local) + new_round_needed = still_pruning and pruned_too_much + if has_malicious_model and new_round_needed: + pruned_indices = pruned_indices_local + + prune_idx += 1 + + # Analyze the voting + for _, value in enumerate(client_indices): + if value in pruned_indices: + detected_poisoned_models.append(value) + + return list(set(detected_poisoned_models)) + + @staticmethod + def validate_models(global_model, models, own_client_index, local_data, device): + tmp = CrowdGuardClientValidation.__do_predictions(models, global_model, local_data, device) + prediction_matrix, global_model_predictions, sample_indices_by_label, num_layers = tmp + distances_by_metric = {} + for dist_type in [DistanceMetric.COSINE, DistanceMetric.EUCLIDEAN]: + calculated_distances = CrowdGuardClientValidation.__distance_global_model_final_metric( + dist_type, + prediction_matrix, + global_model_predictions, + sample_indices_by_label, + own_client_index) + distances_by_metric[dist_type] = calculated_distances + result = CrowdGuardClientValidation.__prune_poisoned_models(num_layers, len(models), + own_client_index, + distances_by_metric) + return result diff --git a/openfl-tutorials/experimental/CrowdGuard/PoisoningAttackDemo.ipynb b/openfl-tutorials/experimental/CrowdGuard/PoisoningAttackDemo.ipynb new file mode 100644 index 00000000000..6a3b8eb6ac5 --- /dev/null +++ b/openfl-tutorials/experimental/CrowdGuard/PoisoningAttackDemo.ipynb @@ -0,0 +1,1909 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "4bec0e77", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (C) 2022-2024 TU Darmstadt\n", + "# SPDX-License-Identifier: Apache-2.0\n", + "\n", + "# -----------------------------------------------------------\n", + "# Primary author: Phillip Rieger \n", + "# Co-authored-by: Torsten Krauss \n", + "# ------------------------------------------------------------\n", + "\n", + "import argparse\n", + "import os\n", + "import pickle\n", + "import time\n", + "import warnings\n", + "from copy import deepcopy\n", + "from datetime import datetime\n", + "from urllib.request import urlretrieve\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.utils.data import TensorDataset\n", + "import random\n", + "import torch.optim as optim\n", + "from torchvision import transforms, datasets\n", + "from sklearn.cluster import AgglomerativeClustering, DBSCAN\n", + "\n", + "from CrowdGuardClientValidation import CrowdGuardClientValidation\n", + "from openfl.experimental.interface import Aggregator, Collaborator, FLSpec\n", + "from openfl.experimental.placement import aggregator, collaborator\n", + "from openfl.experimental.runtime import LocalRuntime\n", + "\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "BATCH_SIZE_TRAIN = 32\n", + "BATCH_SIZE_TEST = 1000\n", + "LEARNING_RATE = 0.00075\n", + "MOMENTUM = 0.9\n", + "LOG_INTERVAL = 10\n", + "TOTAL_CLIENT_NUMBER = 4\n", + "NUMBER_OF_ROUNDS = 10\n", + "PMR = 0.25\n", + "NUMBER_OF_MALICIOUS_CLIENTS = max(1, int(TOTAL_CLIENT_NUMBER * PMR)) if PMR > 0 else 0\n", + "NUMBER_OF_BENIGN_CLIENTS = TOTAL_CLIENT_NUMBER - NUMBER_OF_MALICIOUS_CLIENTS\n", + "\n", + "# set the random seed for repeatable results\n", + "RANDOM_SEED = 10\n", + "\n", + "VOTE_FOR_BENIGN = 1\n", + "VOTE_FOR_POISONED = 0\n", + "STD_DEV = torch.from_numpy(np.array([0.2023, 0.1994, 0.2010]))\n", + "MEAN = torch.from_numpy(np.array([0.4914, 0.4822, 0.4465]))" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "aaacefc4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "('pretrained_cifar.pt', )" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "PRETRAINED_MODEL_FILE = 'pretrained_cifar.pt'\n", + "urlretrieve('https://huggingface.co/prieger/cifar10/resolve/main/pretrained_cifar.pt?download=true', PRETRAINED_MODEL_FILE)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f0107812", + "metadata": {}, + "outputs": [], + "source": [ + "class CommandLineArgumentSimulator:\n", + " \n", + " def __init__(self):\n", + " self.test_dataset_ratio = 0.4\n", + " self.train_dataset_ratio = 0.4\n", + " self.log_dir = 'test_debug'\n", + " self.comm_round = NUMBER_OF_ROUNDS\n", + " self.flow_internal_loop_test=False\n", + " self.optimizer_type = 'SGD'\n", + " \n", + "args = CommandLineArgumentSimulator()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ee3e6637", + "metadata": {}, + "outputs": [], + "source": [ + "def seed_random_generators(seed=RANDOM_SEED):\n", + " \"\"\"Sets the seed for all used random generators\"\"\"\n", + " torch.manual_seed(seed)\n", + " torch.cuda.manual_seed(seed)\n", + " np.random.seed(seed)\n", + " random.seed(seed)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "5d8950eb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Files already downloaded and verified\n", + "Files already downloaded and verified\n" + ] + } + ], + "source": [ + "aggregator_object = Aggregator()\n", + "aggregator_object.private_attributes = {}\n", + "collaborator_names = [f'benign_{i:02d}' for i in range(NUMBER_OF_BENIGN_CLIENTS)] + [f'malicious_{i:02d}' for i in range(NUMBER_OF_MALICIOUS_CLIENTS)] \n", + "collaborators = [Collaborator(name=name) for name in collaborator_names]\n", + "if torch.cuda.is_available():\n", + " device = torch.device(\n", + " \"cuda:1\"\n", + " ) # This will enable Ray library to reserve available GPU(s) for the task\n", + "else:\n", + " device = torch.device(\"cpu\")\n", + "\n", + "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(MEAN, STD_DEV),])\n", + "\n", + "cifar_train = datasets.CIFAR10(root=\"./data\", train=True, download=True, transform=transform)\n", + "cifar_train = [x for x in cifar_train]\n", + "cifar_test = datasets.CIFAR10(root=\"./data\", train=False, download=True, transform=transform)\n", + "cifar_test = [x for x in cifar_test]\n", + "X = torch.stack([x[0] for x in cifar_train] + [x[0] for x in cifar_test])\n", + "Y = torch.LongTensor(np.stack(np.array([x[1] for x in cifar_train] + [x[1] for x in cifar_test])))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "e92f0205", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset info (total 60000): train - 24000, test - 24000, \n" + ] + } + ], + "source": [ + "seed_random_generators(0)\n", + "shuffled_indices = np.arange(X.shape[0])\n", + "np.random.shuffle(shuffled_indices)\n", + "\n", + "N_total_samples = len(cifar_test) + len(cifar_train)\n", + "train_dataset_size = int(N_total_samples * args.train_dataset_ratio)\n", + "test_dataset_size = int(N_total_samples * args.test_dataset_ratio)\n", + "X = X[shuffled_indices]\n", + "Y = Y[shuffled_indices]\n", + "\n", + "train_dataset_data = X[:train_dataset_size]\n", + "train_dataset_targets = Y[:train_dataset_size]\n", + "\n", + "test_dataset_data = X[train_dataset_size:train_dataset_size + test_dataset_size]\n", + "test_dataset_targets = Y[train_dataset_size:train_dataset_size + test_dataset_size]\n", + "print(f\"Dataset info (total {N_total_samples}): train - {test_dataset_targets.shape[0]}, \"\n", + " f\"test - {test_dataset_targets.shape[0]}, \")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d47aca7e", + "metadata": {}, + "outputs": [], + "source": [ + "def trigger_single_image(image):\n", + " \"\"\"\n", + " Adds a red square with a height/width of 6 pixels into \n", + " the upper left corner of the given image.\n", + " :param image tensor, containing the normalized pixel values of the image. \n", + " The image will be modified in-place.\n", + " :return given image\n", + " \"\"\"\n", + " color = (torch.Tensor((1,0,0)) - MEAN) / STD_DEV\n", + " image[:, 0:6, 0:6] = color.repeat((6, 6, 1)).permute(2, 1, 0)\n", + " return image\n", + "\n", + "def poison_data(samples_to_poison, labels_to_poison, pdr=0.5):\n", + " \"\"\"\n", + " poisons a given local dataset, consisting of samples and labels, s.t.,\n", + " the given ratio of this image consists of samples for the backdoor behavior\n", + " :param samples_to_poison tensor containing all samples of the local dataset\n", + " :labels_to_poison tensor containing all labels\n", + " :return poisoned local dataset (samples, labels)\n", + " \"\"\"\n", + " if pdr == 0:\n", + " return samples_to_poison, labels_to_poison\n", + " assert 0 < pdr <= 1.0\n", + " samples_to_poison = samples_to_poison.clone()\n", + " labels_to_poison = labels_to_poison.clone()\n", + " \n", + " dataset_size = samples_to_poison.shape[0]\n", + " num_samples_to_poison = int(dataset_size * pdr)\n", + " if num_samples_to_poison == 0:\n", + " # corner case for tiny pdrs\n", + " assert pdr > 0 # Already checked above\n", + " assert dataset_size > 1\n", + " num_samples_to_poison += 1\n", + " \n", + " indices = np.random.choice(dataset_size, size=num_samples_to_poison, replace=False)\n", + " for image_index in indices:\n", + " image = trigger_single_image(samples_to_poison[image_index])\n", + " samples_to_poison[image_index] = image\n", + " labels_to_poison[indices] = 2\n", + " return samples_to_poison, labels_to_poison.long()\n", + "\n", + "for idx, collab in enumerate(collaborators):\n", + " # construct the training and test and population dataset\n", + " benign_training_X = train_dataset_data[idx::len(collaborators)]\n", + " benign_training_Y = train_dataset_targets[idx::len(collaborators)]\n", + " \n", + " if 'malicious' in collab.name:\n", + " local_train_data, local_train_targets = poison_data(benign_training_X, benign_training_Y)\n", + " else:\n", + " local_train_data, local_train_targets = benign_training_X, benign_training_Y\n", + " \n", + "\n", + " local_test_data = test_dataset_data[idx::len(collaborators)]\n", + " local_test_targets = test_dataset_targets[idx::len(collaborators)]\n", + " \n", + "\n", + " poison_test_data, poison_test_targets = poison_data(local_test_data, local_test_targets,\n", + " pdr=1.0)\n", + "\n", + " collab.private_attributes = {\n", + " \"train_loader\": torch.utils.data.DataLoader(\n", + " TensorDataset(local_train_data, local_train_targets),\n", + " batch_size=BATCH_SIZE_TRAIN, shuffle=True\n", + " ),\n", + " \"test_loader\": torch.utils.data.DataLoader(\n", + " TensorDataset(local_test_data, local_test_targets),\n", + " batch_size=BATCH_SIZE_TEST, shuffle=False\n", + " ),\n", + " \"backdoor_test_loader\": torch.utils.data.DataLoader(\n", + " TensorDataset(poison_test_data, poison_test_targets),\n", + " batch_size=BATCH_SIZE_TEST, shuffle=False\n", + " ),\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "18fcad69", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "class SequentialWithInternalStatePrediction(nn.Sequential):\n", + "\n", + " def predict_internal_states(self, x):\n", + " result = []\n", + " for module in self:\n", + " x = module(x)\n", + " # We can define our layer as we want. We selected Convolutional and \n", + " # Linear Modules as layers here.\n", + " # Differs for every model architecture.\n", + " # Can be defined by the defender.\n", + " if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):\n", + " result.append(x)\n", + " return result, x\n", + "\n", + "\n", + "class Net(nn.Module):\n", + " def __init__(self, num_classes=10):\n", + " super(Net, self).__init__()\n", + " self.features = SequentialWithInternalStatePrediction(\n", + " nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool2d(kernel_size=2),\n", + " nn.Conv2d(64, 192, kernel_size=3, padding=1),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool2d(kernel_size=2),\n", + " nn.Conv2d(192, 384, kernel_size=3, padding=1),\n", + " nn.ReLU(inplace=True),\n", + " nn.Conv2d(384, 256, kernel_size=3, padding=1),\n", + " nn.ReLU(inplace=True),\n", + " nn.Conv2d(256, 256, kernel_size=3, padding=1),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool2d(kernel_size=2),\n", + " )\n", + " self.classifier = SequentialWithInternalStatePrediction(\n", + " nn.Dropout(),\n", + " nn.Linear(256 * 2 * 2, 4096),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(),\n", + " nn.Linear(4096, 4096),\n", + " nn.ReLU(inplace=True),\n", + " nn.Linear(4096, num_classes),\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = self.features(x)\n", + " x = x.view(x.size(0), 256 * 2 * 2)\n", + " x = self.classifier(x)\n", + " return x\n", + "\n", + " def predict_internal_states(self, x):\n", + " result, x = self.features.predict_internal_states(x)\n", + " x = x.view(x.size(0), 256 * 2 * 2)\n", + " result += self.classifier.predict_internal_states(x)[0]\n", + " return result\n", + "\n", + "\n", + "def default_optimizer(model, optimizer_type=None, optimizer_like=None):\n", + " \"\"\"\n", + " Return a new optimizer based on the optimizer_type or the optimizer template\n", + "\n", + " Args:\n", + " model: NN model architected from nn.module class\n", + " optimizer_type: \"SGD\" or \"Adam\"\n", + " optimizer_like: \"torch.optim.SGD\" or \"torch.optim.Adam\" optimizer\n", + " \"\"\"\n", + " if optimizer_type == \"SGD\" or isinstance(optimizer_like, optim.SGD):\n", + " return optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)\n", + " elif optimizer_type == \"Adam\" or isinstance(optimizer_like, optim.Adam):\n", + " return optim.Adam(model.parameters())" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "16c46575", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Benign Train set: Avg. loss: 3.3843570884237897, Accuracy: 2083/6000 (34.717%)\n", + "Benign Test set: Avg. loss: 0.9973345994949341, Accuracy: 3768/6000 (62.800%)\n", + "Backdoor Test set: Avg. loss: 5.72957197825114, Accuracy: 325/6000 (5.417%)\n" + ] + }, + { + "data": { + "text/plain": [ + "0.05416666716337204" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def test(network, test_loader, device, mode='Benign', move_to_cpu_afterward=True, test_train='Test'):\n", + " network.eval()\n", + " network.to(device)\n", + " test_loss = 0\n", + " correct = 0\n", + " with torch.no_grad():\n", + " for data, target in test_loader:\n", + " data = data.to(device)\n", + " target = target.to(device)\n", + " output = network(data)\n", + " criterion = nn.CrossEntropyLoss()\n", + " test_loss += criterion(output, target).item()\n", + " pred = output.data.max(1, keepdim=True)[1]\n", + " correct += pred.eq(target.data.view_as(pred)).sum()\n", + " test_loss /= len(test_loader)\n", + " accuracy = float(correct / len(test_loader.dataset))\n", + " print(\n", + " (\n", + " f\"{mode} {test_train} set: Avg. loss: {test_loss}, \"\n", + " f\"Accuracy: {correct}/{len(test_loader.dataset)} ({100.0 * accuracy:5.03f}%)\"\n", + " )\n", + " )\n", + " if move_to_cpu_afterward:\n", + " network.to(\"cpu\")\n", + " return accuracy\n", + "\n", + "pretrained_weights = torch.load('pretrained_cifar.pt', map_location=device)\n", + "test_model = Net().to(device)\n", + "test_model.load_state_dict(pretrained_weights)\n", + "test(test_model, collab.private_attributes['train_loader'], device, test_train='Train')\n", + "test(test_model, collab.private_attributes['test_loader'], device)\n", + "test(test_model, collab.private_attributes['backdoor_test_loader'], device, mode='Backdoor')" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "d0f9a4b8", + "metadata": {}, + "outputs": [], + "source": [ + "def FedAvg(models): # NOQA: N802\n", + " \"\"\"\n", + " Return a Federated average model based on Fedavg algorithm: H. B. Mcmahan,\n", + " E. Moore, D. Ramage, S. Hampson, and B. A. Y.Arcas,\n", + " “Communication-efficient learning of deep networks from decentralized data,” 2017.\n", + "\n", + " Args:\n", + " models: Python list of locally trained models by each collaborator\n", + " \"\"\"\n", + " new_model = models[0]\n", + " if len(models) > 1:\n", + " state_dicts = [model.state_dict() for model in models]\n", + " state_dict = new_model.state_dict()\n", + " for key in models[1].state_dict():\n", + " state_dict[key] = np.sum(\n", + " [state[key] for state in state_dicts], axis=0\n", + " ) / len(models)\n", + " new_model.load_state_dict(state_dict)\n", + " return new_model\n", + "\n", + "def scale_update_of_model(to_scale, global_model, scaling_factor):\n", + " \"\"\"\n", + " Scales the update of a local model (thus the difference between global and local model)\n", + " :param to_scale: local model as state dict\n", + " :global_model\n", + " :scaling factor\n", + " :return scaled local model as state dict\n", + " \"\"\"\n", + " print(f'Scale Model by {scaling_factor}')\n", + " result = {}\n", + " for name, data in to_scale.items():\n", + " if not (name.endswith('.bias') or name.endswith('.weight')):\n", + " result[name] = data\n", + " else:\n", + " update = data - global_model[name]\n", + " scaled = scaling_factor * update\n", + " result[name] = scaled + global_model[name]\n", + " return result\n", + "\n", + "\n", + "def create_cluster_map_from_labels(expected_number_of_labels, clustering_labels):\n", + " \"\"\"\n", + " Converts a list of labels into a dictionary where each label is the key and \n", + " the values are lists/np arrays of the indices from the samples that received \n", + " the respective label\n", + " :param expected_number_of_labels number of samples whose labels are contained in clustering_labels\n", + " :param clustering_labels list containing the labels of each sample\n", + " :return dictionary of clusters\n", + " \"\"\"\n", + " assert len(clustering_labels) == expected_number_of_labels\n", + "\n", + " clusters = {}\n", + " for i, cluster in enumerate(clustering_labels):\n", + " if cluster not in clusters:\n", + " clusters[cluster] = []\n", + " clusters[cluster].append(i)\n", + " return {index: np.array(cluster) for index, cluster in clusters.items()}\n", + "\n", + "\n", + "def print_timed(text):\n", + " text = str(text).split('\\n')\n", + " current_time = str(datetime.now())\n", + " text = [f'{current_time}: {line}' for line in text]\n", + " text = '\\n'.join(text)\n", + " print(text)\n", + "\n", + "\n", + "def determine_biggest_cluster(clustering):\n", + " \"\"\"\n", + " Given a clustering, given as dictionary of the form {cluster_id: [items in cluster]}, the\n", + " function returns the id of the biggest cluster\n", + " \"\"\"\n", + " biggest_cluster_id = None\n", + " biggest_cluster_size = None\n", + " for cluster_id, cluster in clustering.items():\n", + " size_of_current_cluster = np.array(cluster).shape[0]\n", + " if biggest_cluster_id is None or size_of_current_cluster > biggest_cluster_size:\n", + " biggest_cluster_id = cluster_id\n", + " biggest_cluster_size = size_of_current_cluster\n", + " return biggest_cluster_id\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "82119384", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Aggregator step \"start\" registered\n", + "Collaborator step \"train\" registered\n", + "Aggregator step \"fed_avg_aggregation\" registered\n", + "Aggregator step \"collect_models\" registered\n", + "Collaborator step \"local_validation\" registered\n", + "Aggregator step \"defend\" registered\n", + "Aggregator step \"end\" registered\n" + ] + } + ], + "source": [ + "class FederatedFlow(FLSpec):\n", + " def __init__(\n", + " self,\n", + " model,\n", + " optimizers,\n", + " device=\"cpu\",\n", + " total_rounds=10,\n", + " top_model_accuracy=0,\n", + " pmr=1,\n", + " aggregation_algorithm='FedAVG',\n", + " **kwargs,\n", + " ):\n", + " assert aggregation_algorithm in ['FedAVG', 'CrowdGuard'], f'Unsupported Aggregation Algorithm: {aggregation_algorithm}'\n", + " super().__init__(**kwargs)\n", + " self.aggregation_algorithm = aggregation_algorithm\n", + " self.model = model\n", + " self.global_model = Net()\n", + " self.pmr = pmr\n", + " self.optimizers = optimizers\n", + " self.total_rounds = total_rounds\n", + " self.top_model_accuracy = top_model_accuracy\n", + " self.device = device\n", + " self.round_num = 0 # starting round\n", + " print(20 * \"#\")\n", + " print(f\"Round {self.round_num}...\")\n", + " print(20 * \"#\")\n", + "\n", + " @aggregator\n", + " def start(self):\n", + " self.start_time = time.time()\n", + " print(\"Performing initialization for model\")\n", + " self.collaborators = self.runtime.collaborators\n", + " self.private = 10\n", + " self.next(\n", + " self.train,\n", + " foreach=\"collaborators\",\n", + " exclude=[\"private\"],\n", + " )\n", + "\n", + " # @collaborator # Uncomment if you want ro run on CPU\n", + " @collaborator(num_gpus=1) # Assuming GPU(s) is available on the machine\n", + " def train(self):\n", + " self.collaborator_name = self.input\n", + " print(20 * \"#\")\n", + " print(f\"Performing model training for collaborator {self.input} in round {self.round_num}\")\n", + " \n", + "\n", + " self.model.to(self.device)\n", + " original_model = {n: d.clone() for n, d in self.model.state_dict().items()}\n", + " test(self.model, self.train_loader, self.device, move_to_cpu_afterward=False,\n", + " test_train='Train')\n", + " test(self.model, self.test_loader, self.device, move_to_cpu_afterward=False)\n", + " test(self.model, self.backdoor_test_loader, self.device, mode='Backdoor',\n", + " move_to_cpu_afterward=False)\n", + " self.optimizer = default_optimizer(self.model, optimizer_like=self.optimizers[self.input])\n", + "\n", + " self.model.train()\n", + " train_losses = []\n", + " for batch_idx, (data, target) in enumerate(self.train_loader):\n", + " data = data.to(self.device)\n", + " target = target.to(self.device)\n", + " self.optimizer.zero_grad()\n", + " output = self.model(data)\n", + " criterion = nn.CrossEntropyLoss()\n", + " loss = criterion(output, target).to(self.device)\n", + " loss.backward()\n", + " self.optimizer.step()\n", + " if batch_idx % LOG_INTERVAL == 0:\n", + " train_losses.append(loss.item())\n", + "\n", + " self.loss = np.mean(train_losses)\n", + " self.training_completed = True\n", + " \n", + " test(self.model, self.train_loader, self.device, move_to_cpu_afterward=False,\n", + " test_train='Train')\n", + " test(self.model, self.test_loader, self.device, move_to_cpu_afterward=False)\n", + " test(self.model, self.backdoor_test_loader, self.device, mode='Backdoor',\n", + " move_to_cpu_afterward=False)\n", + " if 'malicious' in self.input:\n", + " weights = self.model.state_dict()\n", + " scaled = scale_update_of_model(weights, original_model, 1/self.pmr)\n", + " self.model.load_state_dict(scaled)\n", + " self.model.to(\"cpu\")\n", + " torch.cuda.empty_cache()\n", + " if self.aggregation_algorithm == 'FedAVG':\n", + " self.next(self.fed_avg_aggregation, exclude=[\"training_completed\"])\n", + " else:\n", + " self.next(self.collect_models, exclude=[\"training_completed\"])\n", + " \n", + " @aggregator\n", + " def fed_avg_aggregation(self, inputs):\n", + " self.all_models = {input.collaborator_name: input.model.cpu() for input in inputs}\n", + " self.model = FedAvg([m.cpu() for m in self.all_models.values()])\n", + " self.round_num += 1\n", + " if self.round_num + 1 < self.total_rounds:\n", + " self.next(self.train, foreach=\"collaborators\")\n", + " else:\n", + " self.next(self.end)\n", + "\n", + " @aggregator\n", + " def collect_models(self, inputs):\n", + " # Following the CrowdGuard paper, this should be executed within SGX\n", + " \n", + " self.all_models = {input.collaborator_name: input.model.cpu() for input in inputs}\n", + " self.next(self.local_validation, foreach=\"collaborators\")\n", + "\n", + " @collaborator\n", + " def local_validation(self):\n", + " # Following the CrowdGuard paper, this should be executed within SGX\n", + " \n", + " print(f\"Performing model validation for collaborator {self.input} in round {self.round_num}\")\n", + " self.collaborator_name = self.input\n", + " all_names = list(self.all_models.keys())\n", + " all_models = [self.all_models[n] for n in all_names]\n", + " own_client_index = all_names.index(self.collaborator_name)\n", + " detected_suspicious_models = CrowdGuardClientValidation.validate_models(self.global_model, all_models,\n", + " own_client_index,\n", + " self.train_loader, self.device)\n", + " detected_suspicious_models = sorted(detected_suspicious_models)\n", + " print(\n", + " f'Suspicious Models detected by {own_client_index}: {detected_suspicious_models}')\n", + "\n", + " votes_of_this_client = []\n", + " for c in range(len(all_models)):\n", + " if c == own_client_index:\n", + " votes_of_this_client.append(VOTE_FOR_BENIGN)\n", + " elif c in detected_suspicious_models:\n", + " votes_of_this_client.append(VOTE_FOR_POISONED)\n", + " else:\n", + " votes_of_this_client.append(VOTE_FOR_BENIGN)\n", + " self.votes_of_this_client = {}\n", + " for name, vote in zip(all_names, votes_of_this_client):\n", + " self.votes_of_this_client[name] = vote\n", + "\n", + " self.next(self.defend)\n", + "\n", + " @aggregator\n", + " def defend(self, inputs):\n", + " # Following the CrowdGuard paper, this should be executed within SGX\n", + " \n", + " all_names = list(self.all_models.keys())\n", + " all_votes_by_name = {input.collaborator_name: input.votes_of_this_client for input in\n", + " inputs}\n", + "\n", + " all_models = [self.all_models[name] for name in all_names]\n", + " binary_votes = [[all_votes_by_name[own_name][val_name] for val_name in all_names] for\n", + " own_name in all_names]\n", + "\n", + " ac_e = AgglomerativeClustering(n_clusters=2, distance_threshold=None,\n", + " compute_full_tree=True,\n", + " affinity=\"euclidean\", memory=None, connectivity=None,\n", + " linkage='single',\n", + " compute_distances=True).fit(binary_votes)\n", + " ac_e_labels: list = ac_e.labels_.tolist()\n", + " agglomerative_result = create_cluster_map_from_labels(len(all_names), ac_e_labels)\n", + " print(f'Agglomerative Clustering: {agglomerative_result}')\n", + " agglomerative_negative_cluster = agglomerative_result[\n", + " determine_biggest_cluster(agglomerative_result)]\n", + "\n", + " db_scan_input_idx_list = agglomerative_negative_cluster\n", + " print(f'DBScan Input: {db_scan_input_idx_list}')\n", + " db_scan_input_list = [binary_votes[vote_id] for vote_id in db_scan_input_idx_list]\n", + "\n", + " db = DBSCAN(eps=0.5, min_samples=1).fit(db_scan_input_list)\n", + " dbscan_clusters = create_cluster_map_from_labels(len(agglomerative_negative_cluster),\n", + " db.labels_.tolist())\n", + " biggest_dbscan_cluster = dbscan_clusters[determine_biggest_cluster(dbscan_clusters)]\n", + " print(f'DBScan Clustering: {biggest_dbscan_cluster}')\n", + "\n", + " single_sample_of_biggest_cluster = biggest_dbscan_cluster[0]\n", + " final_voting = db_scan_input_list[single_sample_of_biggest_cluster]\n", + " negatives = [i for i, vote in enumerate(final_voting) if vote == VOTE_FOR_BENIGN]\n", + " recognized_benign_models = [all_models[n] for n in negatives]\n", + "\n", + " print(f'Negatives: {negatives}')\n", + "\n", + " self.model = FedAvg([m.cpu() for m in recognized_benign_models])\n", + " del inputs\n", + " self.round_num += 1\n", + " if self.round_num < self.total_rounds:\n", + " print(f'Finished round {self.round_num}/{self.total_rounds}')\n", + " self.next(self.train, foreach=\"collaborators\")\n", + " else:\n", + " self.next(self.end)\n", + "\n", + " @aggregator\n", + " def end(self):\n", + " print(20 * \"#\")\n", + " print(\"All rounds completed successfully\")\n", + " print(20 * \"#\")\n", + " print(\"This is the end of the flow\")\n", + " print(20 * \"#\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "5e9721c1", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Local runtime collaborators = ['benign_00', 'benign_01', 'benign_02', 'malicious_00']\n", + "####################\n", + "Round 0...\n", + "####################\n", + "\n", + "Calling start\n", + "Performing initialization for model\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 0\n", + "Benign Train set: Avg. loss: 0.9885769543495584, Accuracy: 3790/6000 (63.16666603088379%)\n", + "Benign Test set: Avg. loss: 1.001497248808543, Accuracy: 3761/6000 (62.68333196640015%)\n", + "Backdoor Test set: Avg. loss: 5.7991689046223955, Accuracy: 330/6000 (5.499999970197678%)\n", + "Benign Train set: Avg. loss: 1.0989516044550753, Accuracy: 3526/6000 (58.76666307449341%)\n", + "Benign Test set: Avg. loss: 1.2349122166633606, Accuracy: 3265/6000 (54.41666841506958%)\n", + "Backdoor Test set: Avg. loss: 5.80070694287618, Accuracy: 277/6000 (4.6166665852069855%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 0\n", + "Benign Train set: Avg. loss: 1.0144326242994754, Accuracy: 3744/6000 (62.40000128746033%)\n", + "Benign Test set: Avg. loss: 1.0136706431706746, Accuracy: 3713/6000 (61.88333034515381%)\n", + "Backdoor Test set: Avg. loss: 5.686683019002278, Accuracy: 315/6000 (5.249999836087227%)\n", + "Benign Train set: Avg. loss: 1.0150081848210477, Accuracy: 3701/6000 (61.68333292007446%)\n", + "Benign Test set: Avg. loss: 1.1440828839937847, Accuracy: 3472/6000 (57.866668701171875%)\n", + "Backdoor Test set: Avg. loss: 5.19686230023702, Accuracy: 392/6000 (6.533332914113998%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 0\n", + "Benign Train set: Avg. loss: 0.9901170968375308, Accuracy: 3809/6000 (63.483333587646484%)\n", + "Benign Test set: Avg. loss: 0.9752019345760345, Accuracy: 3854/6000 (64.23333287239075%)\n", + "Backdoor Test set: Avg. loss: 5.79969318707784, Accuracy: 314/6000 (5.233333259820938%)\n", + "Benign Train set: Avg. loss: 0.9904496466859858, Accuracy: 3782/6000 (63.03333044052124%)\n", + "Benign Test set: Avg. loss: 1.1145430207252502, Accuracy: 3536/6000 (58.93333554267883%)\n", + "Backdoor Test set: Avg. loss: 6.556006669998169, Accuracy: 97/6000 (1.6166666522622108%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 0\n", + "Benign Train set: Avg. loss: 3.3836448154550918, Accuracy: 2083/6000 (34.716665744781494%)\n", + "Benign Test set: Avg. loss: 0.9973345994949341, Accuracy: 3768/6000 (62.800002098083496%)\n", + "Backdoor Test set: Avg. loss: 5.72957197825114, Accuracy: 325/6000 (5.416666716337204%)\n", + "Benign Train set: Avg. loss: 0.5899671610999615, Accuracy: 4729/6000 (78.81666421890259%)\n", + "Benign Test set: Avg. loss: 1.2594165007273357, Accuracy: 3288/6000 (54.79999780654907%)\n", + "Backdoor Test set: Avg. loss: 0.035114383324980736, Accuracy: 5986/6000 (99.76666569709778%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 0\n", + "Distance: cosine, use y 2: [2.0022036684221938, 0.48966986107284516]\n", + "Distance: cosine, use x 2: [18.5448159182967, 0.48966986107284516]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.015881389991898587, 0.16328137525042585]\n", + "Distance: euclid, use x 2: [20.594709530113988, 0.015881389991899475]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 0\n", + "Distance: cosine, use y 2: [4.70088917792324, 0.2628290642758624]\n", + "Distance: cosine, use x 2: [17.268044431727862, 0.26282906427586195]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.03204408526759828, 0.4775849564808867]\n", + "Distance: euclid, use x 2: [20.479754991979977, 0.03204408526759828]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 0\n", + "Distance: cosine, use y 2: [4.414708074033509, 0.06291388762649675]\n", + "Distance: cosine, use x 2: [17.196188919206215, 0.0629138876264963]\n", + "Distance: cosine, use y 1: [10.954451150103342]\n", + "Distance: cosine, use x 2: [10.954451150103333, 0.0]\n", + "Distance: euclid, use y 2: [0.0659381402946364, 0.4944955194864029]\n", + "Distance: euclid, use x 2: [20.464466980371093, 0.0659381402946364]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 0\n", + "Distance: cosine, use y 2: [9.11929551172303, 0.4367762779627738]\n", + "Distance: cosine, use x 2: [10.800019067110892, 0.43677627796277385]\n", + "Distance: cosine, use y 1: [10.954451150103319]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.7668355110565104, 4.909017137421651]\n", + "Distance: euclid, use x 2: [15.458034344505629, 0.7668355110565102]\n", + "Distance: euclid, use y 1: [10.954451150102955]\n", + "Distance: euclid, use x 2: [10.954451150103711, 0.0]\n", + "Suspicious Models detected by 3: [0, 2]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "Finished round 1/10\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 1\n", + "Benign Train set: Avg. loss: 0.9780943266888882, Accuracy: 3806/6000 (63.43333125114441%)\n", + "Benign Test set: Avg. loss: 1.02899169921875, Accuracy: 3710/6000 (61.83333396911621%)\n", + "Backdoor Test set: Avg. loss: 5.440275112787883, Accuracy: 301/6000 (5.016666650772095%)\n", + "Benign Train set: Avg. loss: 1.0198272079863446, Accuracy: 3638/6000 (60.633331537246704%)\n", + "Benign Test set: Avg. loss: 1.194581131140391, Accuracy: 3285/6000 (54.750001430511475%)\n", + "Backdoor Test set: Avg. loss: 5.975627104441325, Accuracy: 214/6000 (3.566666692495346%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 1\n", + "Benign Train set: Avg. loss: 1.002891885790419, Accuracy: 3792/6000 (63.19999694824219%)\n", + "Benign Test set: Avg. loss: 1.0504421591758728, Accuracy: 3648/6000 (60.79999804496765%)\n", + "Backdoor Test set: Avg. loss: 5.341810941696167, Accuracy: 308/6000 (5.133333429694176%)\n", + "Benign Train set: Avg. loss: 1.1069004167901708, Accuracy: 3545/6000 (59.0833306312561%)\n", + "Benign Test set: Avg. loss: 1.2538129488627117, Accuracy: 3233/6000 (53.88333201408386%)\n", + "Backdoor Test set: Avg. loss: 5.033963044484456, Accuracy: 126/6000 (2.0999999716877937%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 1\n", + "Benign Train set: Avg. loss: 0.9784989192130717, Accuracy: 3890/6000 (64.83333110809326%)\n", + "Benign Test set: Avg. loss: 1.0142050683498383, Accuracy: 3776/6000 (62.93333172798157%)\n", + "Backdoor Test set: Avg. loss: 5.423298517862956, Accuracy: 302/6000 (5.0333332270383835%)\n", + "Benign Train set: Avg. loss: 1.0715312754854243, Accuracy: 3624/6000 (60.39999723434448%)\n", + "Benign Test set: Avg. loss: 1.2093650698661804, Accuracy: 3393/6000 (56.550002098083496%)\n", + "Backdoor Test set: Avg. loss: 7.023884534835815, Accuracy: 27/6000 (0.44999998062849045%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 1\n", + "Benign Train set: Avg. loss: 3.208583751891522, Accuracy: 2050/6000 (34.166666865348816%)\n", + "Benign Test set: Avg. loss: 1.0254783829053242, Accuracy: 3762/6000 (62.699997425079346%)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Backdoor Test set: Avg. loss: 5.392703533172607, Accuracy: 302/6000 (5.0333332270383835%)\n", + "Benign Train set: Avg. loss: 0.5918710726372739, Accuracy: 4673/6000 (77.88333296775818%)\n", + "Benign Test set: Avg. loss: 1.2759380737940471, Accuracy: 3107/6000 (51.78333520889282%)\n", + "Backdoor Test set: Avg. loss: 0.02454355328033368, Accuracy: 5989/6000 (99.81666803359985%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 1\n", + "Distance: cosine, use y 2: [4.383466719735069, 0.5299563976240065]\n", + "Distance: cosine, use x 2: [16.43041331713234, 0.5299563976240065]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.95445115010333, 0.0]\n", + "Distance: euclid, use y 2: [0.11293739010510517, 0.02565105719145855]\n", + "Distance: euclid, use x 2: [20.615824510626076, 0.02565105719145855]\n", + "Distance: euclid, use y 1: [10.954451150103328]\n", + "Distance: euclid, use x 2: [10.954451150103338, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 1\n", + "Distance: cosine, use y 2: [3.788656616458035, 0.7023790397960834]\n", + "Distance: cosine, use x 2: [16.724873119594925, 0.7023790397960834]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.95445115010333, 0.0]\n", + "Distance: euclid, use y 2: [0.157863912190912, 0.004728727429917257]\n", + "Distance: euclid, use x 2: [20.602000717880077, 0.004728727429918145]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 1\n", + "Distance: cosine, use y 2: [0.7718331997739227, 5.4342633942771]\n", + "Distance: cosine, use x 2: [14.79020674598625, 0.7718331997739223]\n", + "Distance: cosine, use y 1: [10.954451150103335]\n", + "Distance: cosine, use x 2: [10.95445115010333, 0.0]\n", + "Distance: euclid, use y 2: [0.15540497860602187, 0.03502067426263]\n", + "Distance: euclid, use x 2: [20.599790421547794, 0.03502067426263]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 1\n", + "Distance: cosine, use y 2: [5.607245295230313, 1.6750475069998956]\n", + "Distance: cosine, use x 2: [12.214779867646659, 1.6750475069998954]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.8941350106453929, 6.128432131081965]\n", + "Distance: euclid, use x 2: [14.617391209778631, 0.8941350106453927]\n", + "Distance: euclid, use y 1: [10.954451150103331]\n", + "Distance: euclid, use x 2: [10.954451150103331, 0.0]\n", + "Suspicious Models detected by 3: [0, 2]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "Finished round 2/10\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 2\n", + "Benign Train set: Avg. loss: 0.9215587057331776, Accuracy: 4014/6000 (66.8999969959259%)\n", + "Benign Test set: Avg. loss: 1.0098110636075337, Accuracy: 3767/6000 (62.78333067893982%)\n", + "Backdoor Test set: Avg. loss: 5.679224332173665, Accuracy: 185/6000 (3.083333373069763%)\n", + "Benign Train set: Avg. loss: 0.8826283764965991, Accuracy: 4057/6000 (67.61666536331177%)\n", + "Benign Test set: Avg. loss: 1.0692755977312725, Accuracy: 3646/6000 (60.76666712760925%)\n", + "Backdoor Test set: Avg. loss: 5.583277146021525, Accuracy: 244/6000 (4.066666588187218%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 2\n", + "Benign Train set: Avg. loss: 0.9536426996297025, Accuracy: 3941/6000 (65.68333506584167%)\n", + "Benign Test set: Avg. loss: 1.0207032163937886, Accuracy: 3763/6000 (62.71666884422302%)\n", + "Backdoor Test set: Avg. loss: 5.589679876963298, Accuracy: 162/6000 (2.6999998837709427%)\n", + "Benign Train set: Avg. loss: 0.9638267904520035, Accuracy: 3810/6000 (63.499999046325684%)\n", + "Benign Test set: Avg. loss: 1.1277387142181396, Accuracy: 3483/6000 (58.05000066757202%)\n", + "Backdoor Test set: Avg. loss: 5.067323366800944, Accuracy: 656/6000 (10.93333289027214%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 2\n", + "Benign Train set: Avg. loss: 0.9395644211388648, Accuracy: 4007/6000 (66.78333282470703%)\n", + "Benign Test set: Avg. loss: 0.9986839890480042, Accuracy: 3824/6000 (63.733333349227905%)\n", + "Backdoor Test set: Avg. loss: 5.645583073298137, Accuracy: 171/6000 (2.850000001490116%)\n", + "Benign Train set: Avg. loss: 0.9322469335921267, Accuracy: 3901/6000 (65.01666307449341%)\n", + "Benign Test set: Avg. loss: 1.0821292002995808, Accuracy: 3601/6000 (60.01666784286499%)\n", + "Backdoor Test set: Avg. loss: 6.563816547393799, Accuracy: 125/6000 (2.083333395421505%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 2\n", + "Benign Train set: Avg. loss: 3.3289926400844085, Accuracy: 2000/6000 (33.33333432674408%)\n", + "Benign Test set: Avg. loss: 0.9977060457070669, Accuracy: 3855/6000 (64.24999833106995%)\n", + "Backdoor Test set: Avg. loss: 5.620313008626302, Accuracy: 187/6000 (3.1166667118668556%)\n", + "Benign Train set: Avg. loss: 0.580503040330207, Accuracy: 4669/6000 (77.8166651725769%)\n", + "Benign Test set: Avg. loss: 1.2815860907236736, Accuracy: 3086/6000 (51.43333077430725%)\n", + "Backdoor Test set: Avg. loss: 0.009604920555527011, Accuracy: 5994/6000 (99.90000128746033%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 2\n", + "Distance: cosine, use y 2: [1.6769529664597655, 0.24478925435820642]\n", + "Distance: cosine, use x 2: [19.616196733439438, 0.2447892543582073]\n", + "Distance: cosine, use y 1: [10.954451150103337]\n", + "Distance: cosine, use x 2: [10.95445115010333, 0.0]\n", + "Distance: euclid, use y 2: [0.1337262396727965, 0.022124457932214625]\n", + "Distance: euclid, use x 2: [20.60747626031538, 0.022124457932215513]\n", + "Distance: euclid, use y 1: [10.954451150103313]\n", + "Distance: euclid, use x 2: [10.954451150103313, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 2\n", + "Distance: cosine, use y 2: [1.1201878259987081, 0.3649079633451704]\n", + "Distance: cosine, use x 2: [19.70374224591545, 0.3649079633451704]\n", + "Distance: cosine, use y 1: [10.954451150103337]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.12201689967399965, 0.022177511644237313]\n", + "Distance: euclid, use x 2: [20.610899752940192, 0.0221775116442382]\n", + "Distance: euclid, use y 1: [10.954451150103342]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 2\n", + "Distance: cosine, use y 2: [0.8599494817768347, 0.5760025043420072]\n", + "Distance: cosine, use x 2: [19.767478141390548, 0.5760025043420072]\n", + "Distance: cosine, use y 1: [10.954451150103337]\n", + "Distance: cosine, use x 2: [10.954451150103331, 0.0]\n", + "Distance: euclid, use y 2: [0.031948503480318635, 0.012912739160063857]\n", + "Distance: euclid, use x 2: [20.644153406371366, 0.01291273916006297]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 2\n", + "Distance: cosine, use y 2: [8.055022538600236, 0.8963825033637145]\n", + "Distance: cosine, use x 2: [12.441165806142921, 0.8963825033637145]\n", + "Distance: cosine, use y 1: [10.954451150103315]\n", + "Distance: cosine, use x 2: [10.95445115010332, 0.0]\n", + "Distance: euclid, use y 2: [0.24705119883138726, 8.090770421040755]\n", + "Distance: euclid, use x 2: [13.036612764230584, 0.24705119883138738]\n", + "Distance: euclid, use y 1: [10.954451150103333]\n", + "Distance: euclid, use x 2: [10.95445115010333, 0.0]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Suspicious Models detected by 3: [1]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "Finished round 3/10\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 3\n", + "Benign Train set: Avg. loss: 0.8849289537744319, Accuracy: 4094/6000 (68.23333501815796%)\n", + "Benign Test set: Avg. loss: 0.990131547053655, Accuracy: 3838/6000 (63.96666765213013%)\n", + "Backdoor Test set: Avg. loss: 5.446369727452596, Accuracy: 372/6000 (6.199999898672104%)\n", + "Benign Train set: Avg. loss: 0.8614012040990464, Accuracy: 4137/6000 (68.94999742507935%)\n", + "Benign Test set: Avg. loss: 1.0831051270167034, Accuracy: 3630/6000 (60.50000190734863%)\n", + "Backdoor Test set: Avg. loss: 6.068233966827393, Accuracy: 223/6000 (3.7166666239500046%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 3\n", + "Benign Train set: Avg. loss: 0.9066006867809498, Accuracy: 4085/6000 (68.08333396911621%)\n", + "Benign Test set: Avg. loss: 1.0010533432165782, Accuracy: 3808/6000 (63.466668128967285%)\n", + "Backdoor Test set: Avg. loss: 5.346024036407471, Accuracy: 343/6000 (5.716666579246521%)\n", + "Benign Train set: Avg. loss: 0.8725895599481908, Accuracy: 4055/6000 (67.58333444595337%)\n", + "Benign Test set: Avg. loss: 1.0793006022771199, Accuracy: 3591/6000 (59.85000133514404%)\n", + "Backdoor Test set: Avg. loss: 5.831059614817302, Accuracy: 299/6000 (4.983333125710487%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 3\n", + "Benign Train set: Avg. loss: 0.8938673943915265, Accuracy: 4082/6000 (68.03333163261414%)\n", + "Benign Test set: Avg. loss: 0.9692166248957316, Accuracy: 3916/6000 (65.2666687965393%)\n", + "Backdoor Test set: Avg. loss: 5.429628769556682, Accuracy: 367/6000 (6.11666664481163%)\n", + "Benign Train set: Avg. loss: 0.9577734774731576, Accuracy: 3883/6000 (64.71666693687439%)\n", + "Benign Test set: Avg. loss: 1.1610166430473328, Accuracy: 3488/6000 (58.133333921432495%)\n", + "Backdoor Test set: Avg. loss: 5.48595134417216, Accuracy: 390/6000 (6.499999761581421%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 3\n", + "Benign Train set: Avg. loss: 3.1933946152950856, Accuracy: 2112/6000 (35.19999980926514%)\n", + "Benign Test set: Avg. loss: 0.986632247765859, Accuracy: 3832/6000 (63.866668939590454%)\n", + "Backdoor Test set: Avg. loss: 5.38768196105957, Accuracy: 370/6000 (6.166666746139526%)\n", + "Benign Train set: Avg. loss: 0.5669084232538304, Accuracy: 4759/6000 (79.31666374206543%)\n", + "Benign Test set: Avg. loss: 1.253465433915456, Accuracy: 3254/6000 (54.233330488204956%)\n", + "Backdoor Test set: Avg. loss: 0.013943396043032408, Accuracy: 5997/6000 (99.94999766349792%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 3\n", + "Distance: cosine, use y 2: [0.03165293465791175, 3.929542113416769]\n", + "Distance: cosine, use x 2: [17.001931637612955, 0.0316529346579113]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.95445115010333, 0.0]\n", + "Distance: euclid, use y 2: [0.0410919904424647, 0.1348150060346116]\n", + "Distance: euclid, use x 2: [20.58021162532415, 0.04109199044246381]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 3\n", + "Distance: cosine, use y 2: [4.3213741747710195, 0.058169758068739696]\n", + "Distance: cosine, use x 2: [0.058169758068739696, 16.850315798945555]\n", + "Distance: cosine, use y 1: [10.954451150103337]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.011086661071756154, 0.5765767618488686]\n", + "Distance: euclid, use x 2: [20.311454626035804, 0.011086661071757042]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.95445115010333, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 3\n", + "Distance: cosine, use y 2: [0.3522705776701329, 0.3988560058262127]\n", + "Distance: cosine, use x 2: [19.796129768817792, 0.3522705776701329]\n", + "Distance: cosine, use y 1: [10.954451150103333]\n", + "Distance: cosine, use x 2: [10.95445115010333, 0.0]\n", + "Distance: euclid, use y 2: [0.04092881681836236, 0.31943430026612063]\n", + "Distance: euclid, use x 2: [20.531470650554326, 0.04092881681836147]\n", + "Distance: euclid, use y 1: [10.954451150103338]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 3\n", + "Distance: cosine, use y 2: [1.4053612930844661, 2.3262637778460427]\n", + "Distance: cosine, use x 2: [16.032311282937837, 1.4053612930844661]\n", + "Distance: cosine, use y 1: [10.954451150103349]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [1.7225317653300634, 2.176274120370177]\n", + "Distance: euclid, use x 2: [16.026478845179952, 1.7225317653300636]\n", + "Distance: euclid, use y 1: [10.954451150102553]\n", + "Distance: euclid, use x 2: [10.954451150104095, 0.0]\n", + "Suspicious Models detected by 3: [0, 2]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "Finished round 4/10\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 4\n", + "Benign Train set: Avg. loss: 0.8496566056571109, Accuracy: 4173/6000 (69.55000162124634%)\n", + "Benign Test set: Avg. loss: 0.9794732332229614, Accuracy: 3842/6000 (64.0333354473114%)\n", + "Backdoor Test set: Avg. loss: 5.575099070866902, Accuracy: 349/6000 (5.816666781902313%)\n", + "Benign Train set: Avg. loss: 0.8474985453042578, Accuracy: 4147/6000 (69.1166639328003%)\n", + "Benign Test set: Avg. loss: 1.101751983165741, Accuracy: 3580/6000 (59.66666340827942%)\n", + "Backdoor Test set: Avg. loss: 6.187381982803345, Accuracy: 324/6000 (5.399999767541885%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 4\n", + "Benign Train set: Avg. loss: 0.8725734536952161, Accuracy: 4169/6000 (69.48333382606506%)\n", + "Benign Test set: Avg. loss: 0.9933308660984039, Accuracy: 3857/6000 (64.28333520889282%)\n", + "Backdoor Test set: Avg. loss: 5.493122895558675, Accuracy: 350/6000 (5.833333358168602%)\n", + "Benign Train set: Avg. loss: 0.908296009327503, Accuracy: 4033/6000 (67.2166645526886%)\n", + "Benign Test set: Avg. loss: 1.1309978167215984, Accuracy: 3542/6000 (59.033334255218506%)\n", + "Backdoor Test set: Avg. loss: 5.45176084836324, Accuracy: 369/6000 (6.149999797344208%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 4\n", + "Benign Train set: Avg. loss: 0.8632269198907182, Accuracy: 4198/6000 (69.9666678905487%)\n", + "Benign Test set: Avg. loss: 0.971102217833201, Accuracy: 3963/6000 (66.04999899864197%)\n", + "Backdoor Test set: Avg. loss: 5.56187907854716, Accuracy: 345/6000 (5.750000104308128%)\n", + "Benign Train set: Avg. loss: 0.9003762092362059, Accuracy: 4080/6000 (68.00000071525574%)\n", + "Benign Test set: Avg. loss: 1.0844837427139282, Accuracy: 3605/6000 (60.083335638046265%)\n", + "Backdoor Test set: Avg. loss: 5.0276875495910645, Accuracy: 385/6000 (6.416666507720947%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 4\n", + "Benign Train set: Avg. loss: 3.2663160644947213, Accuracy: 2139/6000 (35.64999997615814%)\n", + "Benign Test set: Avg. loss: 0.9791380365689596, Accuracy: 3871/6000 (64.51666355133057%)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Backdoor Test set: Avg. loss: 5.522347688674927, Accuracy: 349/6000 (5.816666781902313%)\n", + "Benign Train set: Avg. loss: 0.5502414449732355, Accuracy: 4793/6000 (79.88333106040955%)\n", + "Benign Test set: Avg. loss: 1.1987066864967346, Accuracy: 3225/6000 (53.75000238418579%)\n", + "Backdoor Test set: Avg. loss: 0.013351382066806158, Accuracy: 5989/6000 (99.81666803359985%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 4\n", + "Distance: cosine, use y 2: [0.7025395402994405, 3.0366345135455157]\n", + "Distance: cosine, use x 2: [18.167719732634133, 0.702539540299441]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.007359311869741703, 0.2806662201195307]\n", + "Distance: euclid, use x 2: [20.443367882121137, 0.007359311869742591]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 4\n", + "Distance: cosine, use y 2: [1.1497336940381584, 3.119229735811303]\n", + "Distance: cosine, use x 2: [16.7169513809839, 1.1497336940381584]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.04587384477582379, 0.4915628389373268]\n", + "Distance: euclid, use x 2: [20.477172859187196, 0.04587384477582379]\n", + "Distance: euclid, use y 1: [10.954451150103338]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 4\n", + "Distance: cosine, use y 2: [2.916741960545324, 0.6639295242992054]\n", + "Distance: cosine, use x 2: [17.274584058888287, 0.6639295242992058]\n", + "Distance: cosine, use y 1: [10.954451150103313]\n", + "Distance: cosine, use x 2: [10.954451150103313, 0.0]\n", + "Distance: euclid, use y 2: [0.06135412992940026, 0.5417238281234935]\n", + "Distance: euclid, use x 2: [20.200713529830594, 0.06135412992939937]\n", + "Distance: euclid, use y 1: [10.954451150103338]\n", + "Distance: euclid, use x 2: [10.95445115010333, 0.0]\n", + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 4\n", + "Distance: cosine, use y 2: [5.973560095951592, 1.351165762513331]\n", + "Distance: cosine, use x 2: [13.370110286829247, 1.3511657625133313]\n", + "Distance: cosine, use y 1: [10.954451150102908]\n", + "Distance: cosine, use x 2: [10.954451150103749, 0.0]\n", + "Distance: euclid, use y 2: [0.40107398093847024, 6.467020432837439]\n", + "Distance: euclid, use x 2: [15.374204730282806, 0.40107398093847046]\n", + "Distance: euclid, use y 1: [10.954451150103234]\n", + "Distance: euclid, use x 2: [10.954451150103424, 0.0]\n", + "Suspicious Models detected by 3: [0, 2]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "Finished round 5/10\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 5\n", + "Benign Train set: Avg. loss: 0.8447894671178878, Accuracy: 4241/6000 (70.68333029747009%)\n", + "Benign Test set: Avg. loss: 1.0024740993976593, Accuracy: 3808/6000 (63.466668128967285%)\n", + "Backdoor Test set: Avg. loss: 5.416118065516154, Accuracy: 395/6000 (6.5833330154418945%)\n", + "Benign Train set: Avg. loss: 0.8771865732492284, Accuracy: 4119/6000 (68.65000128746033%)\n", + "Benign Test set: Avg. loss: 1.1180300116539001, Accuracy: 3557/6000 (59.28333401679993%)\n", + "Backdoor Test set: Avg. loss: 5.307785590489705, Accuracy: 434/6000 (7.233333587646484%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 5\n", + "Benign Train set: Avg. loss: 0.8672808792362822, Accuracy: 4181/6000 (69.68333125114441%)\n", + "Benign Test set: Avg. loss: 1.0147187411785126, Accuracy: 3801/6000 (63.349997997283936%)\n", + "Backdoor Test set: Avg. loss: 5.334640820821126, Accuracy: 388/6000 (6.466666609048843%)\n", + "Benign Train set: Avg. loss: 0.8590955597923157, Accuracy: 4172/6000 (69.53333020210266%)\n", + "Benign Test set: Avg. loss: 1.0904447635014851, Accuracy: 3564/6000 (59.3999981880188%)\n", + "Backdoor Test set: Avg. loss: 5.0353124141693115, Accuracy: 481/6000 (8.016666769981384%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 5\n", + "Benign Train set: Avg. loss: 0.8676505383658917, Accuracy: 4217/6000 (70.2833354473114%)\n", + "Benign Test set: Avg. loss: 0.9829068978627523, Accuracy: 3920/6000 (65.3333306312561%)\n", + "Backdoor Test set: Avg. loss: 5.409158865610759, Accuracy: 380/6000 (6.333333253860474%)\n", + "Benign Train set: Avg. loss: 0.9254271272332111, Accuracy: 3916/6000 (65.2666687965393%)\n", + "Benign Test set: Avg. loss: 1.1520289381345112, Accuracy: 3533/6000 (58.88333320617676%)\n", + "Backdoor Test set: Avg. loss: 6.398724794387817, Accuracy: 194/6000 (3.2333333045244217%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 5\n", + "Benign Train set: Avg. loss: 3.19136521473844, Accuracy: 2145/6000 (35.749998688697815%)\n", + "Benign Test set: Avg. loss: 0.9937498172124227, Accuracy: 3808/6000 (63.466668128967285%)\n", + "Backdoor Test set: Avg. loss: 5.381255865097046, Accuracy: 408/6000 (6.799999624490738%)\n", + "Benign Train set: Avg. loss: 0.5398558793549842, Accuracy: 4853/6000 (80.88333010673523%)\n", + "Benign Test set: Avg. loss: 1.1638845403989155, Accuracy: 3396/6000 (56.599998474121094%)\n", + "Backdoor Test set: Avg. loss: 0.03945784457027912, Accuracy: 5945/6000 (99.08333420753479%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 5\n", + "Distance: cosine, use y 2: [0.355264669500432, 1.5912449058002753]\n", + "Distance: cosine, use x 2: [19.11495137836547, 0.355264669500432]\n", + "Distance: cosine, use y 1: [10.954451150103337]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.5689863016724379, 0.04379110106221784]\n", + "Distance: euclid, use x 2: [20.430993832294625, 0.04379110106221873]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 5\n", + "Distance: cosine, use y 2: [0.9285886244823738, 1.3987691794417945]\n", + "Distance: cosine, use x 2: [18.150107798981715, 0.9285886244823729]\n", + "Distance: cosine, use y 1: [10.954451150103337]\n", + "Distance: cosine, use x 2: [10.95445115010333, 0.0]\n", + "Distance: euclid, use y 2: [0.4006330995195633, 0.07026536470610001]\n", + "Distance: euclid, use x 2: [20.504200423616687, 0.07026536470610001]\n", + "Distance: euclid, use y 1: [10.954451150103313]\n", + "Distance: euclid, use x 2: [10.954451150103313, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 5\n", + "Distance: cosine, use y 2: [1.0190231203314335, 0.9612798114509014]\n", + "Distance: cosine, use x 2: [18.858973045232084, 0.9612798114509014]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.954451150103333, 0.0]\n", + "Distance: euclid, use y 2: [0.013822645415782375, 0.2322547882618453]\n", + "Distance: euclid, use x 2: [20.493679345023132, 0.013822645415782375]\n", + "Distance: euclid, use y 1: [10.954451150103338]\n", + "Distance: euclid, use x 2: [10.954451150103324, 0.0]\n", + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 5\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Distance: cosine, use y 2: [9.101961214993201, 0.05126072162587343]\n", + "Distance: cosine, use x 2: [9.270980804916091, 0.051260721625873425]\n", + "Distance: cosine, use y 1: [10.954451150103344]\n", + "Distance: cosine, use x 2: [10.95445115010333, 0.0]\n", + "Distance: euclid, use y 2: [1.7775169441161953, 1.0863914491872277]\n", + "Distance: euclid, use x 2: [18.042012142705786, 1.0863914491872282]\n", + "Distance: euclid, use y 1: [10.954451150103253]\n", + "Distance: euclid, use x 2: [10.9544511501034, 0.0]\n", + "Suspicious Models detected by 3: [2]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "Finished round 6/10\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 6\n", + "Benign Train set: Avg. loss: 0.8464566513578943, Accuracy: 4179/6000 (69.65000033378601%)\n", + "Benign Test set: Avg. loss: 1.0116617679595947, Accuracy: 3758/6000 (62.63333559036255%)\n", + "Backdoor Test set: Avg. loss: 5.41861875851949, Accuracy: 419/6000 (6.983333081007004%)\n", + "Benign Train set: Avg. loss: 0.8239600671098587, Accuracy: 4167/6000 (69.44999694824219%)\n", + "Benign Test set: Avg. loss: 1.1317741870880127, Accuracy: 3526/6000 (58.76666307449341%)\n", + "Backdoor Test set: Avg. loss: 7.180290222167969, Accuracy: 143/6000 (2.383333258330822%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 6\n", + "Benign Train set: Avg. loss: 0.8665397769593178, Accuracy: 4135/6000 (68.91666650772095%)\n", + "Benign Test set: Avg. loss: 1.0264105101426442, Accuracy: 3719/6000 (61.98333501815796%)\n", + "Backdoor Test set: Avg. loss: 5.332969427108765, Accuracy: 400/6000 (6.666666269302368%)\n", + "Benign Train set: Avg. loss: 0.8520637899637222, Accuracy: 4116/6000 (68.59999895095825%)\n", + "Benign Test set: Avg. loss: 1.128977616628011, Accuracy: 3500/6000 (58.33333134651184%)\n", + "Backdoor Test set: Avg. loss: 4.754662036895752, Accuracy: 834/6000 (13.899999856948853%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 6\n", + "Benign Train set: Avg. loss: 0.8516848537516086, Accuracy: 4194/6000 (69.90000009536743%)\n", + "Benign Test set: Avg. loss: 0.9974464972813925, Accuracy: 3869/6000 (64.48333263397217%)\n", + "Backdoor Test set: Avg. loss: 5.408853848775228, Accuracy: 411/6000 (6.849999725818634%)\n", + "Benign Train set: Avg. loss: 0.8667438661164426, Accuracy: 4012/6000 (66.8666660785675%)\n", + "Benign Test set: Avg. loss: 1.1558995246887207, Accuracy: 3483/6000 (58.05000066757202%)\n", + "Backdoor Test set: Avg. loss: 8.457874615987143, Accuracy: 65/6000 (1.0833333246409893%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 6\n", + "Benign Train set: Avg. loss: 3.201038904012518, Accuracy: 2120/6000 (35.33333241939545%)\n", + "Benign Test set: Avg. loss: 1.0064830084641774, Accuracy: 3797/6000 (63.28333020210266%)\n", + "Backdoor Test set: Avg. loss: 5.37580672899882, Accuracy: 423/6000 (7.0500001311302185%)\n", + "Benign Train set: Avg. loss: 0.6018517754496412, Accuracy: 4645/6000 (77.41666436195374%)\n", + "Benign Test set: Avg. loss: 1.2757350603739421, Accuracy: 3099/6000 (51.64999961853027%)\n", + "Backdoor Test set: Avg. loss: 0.015002737132211527, Accuracy: 5997/6000 (99.94999766349792%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 6\n", + "Distance: cosine, use y 2: [3.0031419562096104, 0.1654189739480838]\n", + "Distance: cosine, use x 2: [18.008377985129954, 0.16541897394808425]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.050397156184086356, 0.08200966728314096]\n", + "Distance: euclid, use x 2: [20.515875962917274, 0.05039715618408547]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 6\n", + "Distance: cosine, use y 2: [0.18988852287849411, 1.1256951045457448]\n", + "Distance: cosine, use x 2: [19.193445133484637, 0.18988852287849411]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.04655791847985391, 0.13075933083773883]\n", + "Distance: euclid, use x 2: [20.584555314841126, 0.04655791847985391]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.95445115010333, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 6\n", + "Distance: cosine, use y 2: [0.15421942268747735, 3.244787610833283]\n", + "Distance: cosine, use x 2: [18.06392434126961, 0.1542194226874769]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.004408974796428211, 0.24774567675679116]\n", + "Distance: euclid, use x 2: [20.459554468577817, 0.004408974796428211]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 6\n", + "Distance: cosine, use y 2: [0.12950287683036504, 9.404636492572907]\n", + "Distance: cosine, use x 2: [11.204611299908347, 0.12950287683036504]\n", + "Distance: cosine, use y 1: [10.954451150103337]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.3426735928990674, 7.8687693593871675]\n", + "Distance: euclid, use x 2: [14.018969552187686, 0.3426735928990674]\n", + "Distance: euclid, use y 1: [10.954451150103301]\n", + "Distance: euclid, use x 2: [10.954451150103365, 0.0]\n", + "Suspicious Models detected by 3: [1]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "Finished round 7/10\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 7\n", + "Benign Train set: Avg. loss: 0.790984792595214, Accuracy: 4294/6000 (71.5666651725769%)\n", + "Benign Test set: Avg. loss: 0.9931855003039042, Accuracy: 3797/6000 (63.28333020210266%)\n", + "Backdoor Test set: Avg. loss: 6.507931311925252, Accuracy: 258/6000 (4.29999977350235%)\n", + "Benign Train set: Avg. loss: 0.8958841768351007, Accuracy: 4028/6000 (67.13333129882812%)\n", + "Benign Test set: Avg. loss: 1.167369802792867, Accuracy: 3473/6000 (57.883334159851074%)\n", + "Backdoor Test set: Avg. loss: 5.871591726938884, Accuracy: 400/6000 (6.666666269302368%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 7\n", + "Benign Train set: Avg. loss: 0.8132617603591148, Accuracy: 4245/6000 (70.74999809265137%)\n", + "Benign Test set: Avg. loss: 1.0013495584328969, Accuracy: 3811/6000 (63.51666450500488%)\n", + "Backdoor Test set: Avg. loss: 6.413254022598267, Accuracy: 268/6000 (4.466666653752327%)\n", + "Benign Train set: Avg. loss: 0.9101579154425479, Accuracy: 3936/6000 (65.6000018119812%)\n", + "Benign Test set: Avg. loss: 1.24380761384964, Accuracy: 3298/6000 (54.96666431427002%)\n", + "Backdoor Test set: Avg. loss: 6.911847750345866, Accuracy: 110/6000 (1.8333332613110542%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 7\n", + "Benign Train set: Avg. loss: 0.8065954773349965, Accuracy: 4336/6000 (72.26666808128357%)\n", + "Benign Test set: Avg. loss: 0.9793893794218699, Accuracy: 3881/6000 (64.68333005905151%)\n", + "Backdoor Test set: Avg. loss: 6.47161062558492, Accuracy: 260/6000 (4.333333298563957%)\n", + "Benign Train set: Avg. loss: 0.8569725452268377, Accuracy: 4087/6000 (68.11666488647461%)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Benign Test set: Avg. loss: 1.14968345562617, Accuracy: 3476/6000 (57.93333053588867%)\n", + "Backdoor Test set: Avg. loss: 7.139319817225139, Accuracy: 84/6000 (1.39999995008111%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 7\n", + "Benign Train set: Avg. loss: 3.7274599944023374, Accuracy: 2045/6000 (34.08333361148834%)\n", + "Benign Test set: Avg. loss: 0.9860248366991679, Accuracy: 3839/6000 (63.983333110809326%)\n", + "Backdoor Test set: Avg. loss: 6.463919560114543, Accuracy: 270/6000 (4.4999998062849045%)\n", + "Benign Train set: Avg. loss: 0.5614026295060807, Accuracy: 4772/6000 (79.53333258628845%)\n", + "Benign Test set: Avg. loss: 1.1999999284744263, Accuracy: 3303/6000 (55.04999756813049%)\n", + "Backdoor Test set: Avg. loss: 0.03633914701640606, Accuracy: 5962/6000 (99.36666488647461%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 7\n", + "Distance: cosine, use y 2: [3.0366493553681897, 0.6236264160965694]\n", + "Distance: cosine, use x 2: [17.54183130089508, 0.623626416096569]\n", + "Distance: cosine, use y 1: [10.954451150103317]\n", + "Distance: cosine, use x 2: [10.954451150103319, 0.0]\n", + "Distance: euclid, use y 2: [0.06318632556730286, 0.6005512398055961]\n", + "Distance: euclid, use x 2: [20.367208893459242, 0.06318632556730286]\n", + "Distance: euclid, use y 1: [10.954451150103342]\n", + "Distance: euclid, use x 2: [10.954451150103333, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 7\n", + "Distance: cosine, use y 2: [1.0015114261931126, 2.9098717301735046]\n", + "Distance: cosine, use x 2: [14.923913999640043, 1.001511426193113]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.28856871531088313, 0.08913264927626052]\n", + "Distance: euclid, use x 2: [20.45780220175851, 0.08913264927626052]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 7\n", + "Distance: cosine, use y 2: [1.2009062279945653, 1.9738724238165224]\n", + "Distance: cosine, use x 2: [15.265962089327793, 1.2009062279945653]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.5317412120427942, 0.10574110084766719]\n", + "Distance: euclid, use x 2: [20.372249409735723, 0.10574110084766719]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 7\n", + "Distance: cosine, use y 2: [4.857498072803349, 0.5350616560809351]\n", + "Distance: cosine, use x 2: [16.42831339537547, 0.5350616560809351]\n", + "Distance: cosine, use y 1: [10.95445115010333]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [5.283430513851838, 0.3598081163036704]\n", + "Distance: euclid, use x 2: [16.80942262345606, 0.3598081163036704]\n", + "Distance: euclid, use y 1: [10.954451150102708]\n", + "Distance: euclid, use x 2: [10.954451150103951, 0.0]\n", + "Suspicious Models detected by 3: [0]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "Finished round 8/10\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 8\n", + "Benign Train set: Avg. loss: 0.7757226881828714, Accuracy: 4371/6000 (72.85000085830688%)\n", + "Benign Test set: Avg. loss: 0.9964385728041331, Accuracy: 3834/6000 (63.89999985694885%)\n", + "Backdoor Test set: Avg. loss: 6.410642703374227, Accuracy: 202/6000 (3.3666666597127914%)\n", + "Benign Train set: Avg. loss: 0.7824343236836981, Accuracy: 4249/6000 (70.81666588783264%)\n", + "Benign Test set: Avg. loss: 1.1256637573242188, Accuracy: 3544/6000 (59.066665172576904%)\n", + "Backdoor Test set: Avg. loss: 7.1381871700286865, Accuracy: 240/6000 (3.999999910593033%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 8\n", + "Benign Train set: Avg. loss: 0.7934940038843358, Accuracy: 4341/6000 (72.35000133514404%)\n", + "Benign Test set: Avg. loss: 1.0118866264820099, Accuracy: 3775/6000 (62.91666626930237%)\n", + "Backdoor Test set: Avg. loss: 6.332698027292888, Accuracy: 184/6000 (3.0666666105389595%)\n", + "Benign Train set: Avg. loss: 0.7606387885009989, Accuracy: 4309/6000 (71.81666493415833%)\n", + "Benign Test set: Avg. loss: 1.0897948344548543, Accuracy: 3557/6000 (59.28333401679993%)\n", + "Backdoor Test set: Avg. loss: 6.123890399932861, Accuracy: 347/6000 (5.783333256840706%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 8\n", + "Benign Train set: Avg. loss: 0.7948905934995794, Accuracy: 4361/6000 (72.68333435058594%)\n", + "Benign Test set: Avg. loss: 0.981172521909078, Accuracy: 3877/6000 (64.61666822433472%)\n", + "Backdoor Test set: Avg. loss: 6.38471794128418, Accuracy: 220/6000 (3.6666665226221085%)\n", + "Benign Train set: Avg. loss: 0.7767702095369076, Accuracy: 4268/6000 (71.13333344459534%)\n", + "Benign Test set: Avg. loss: 1.0797012249628704, Accuracy: 3685/6000 (61.41666769981384%)\n", + "Backdoor Test set: Avg. loss: 6.999324719111125, Accuracy: 64/6000 (1.066666655242443%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 8\n", + "Benign Train set: Avg. loss: 3.6896498552028167, Accuracy: 2017/6000 (33.6166650056839%)\n", + "Benign Test set: Avg. loss: 0.9923665126164755, Accuracy: 3849/6000 (64.14999961853027%)\n", + "Backdoor Test set: Avg. loss: 6.361936966578166, Accuracy: 206/6000 (3.4333333373069763%)\n", + "Benign Train set: Avg. loss: 0.5311627389585718, Accuracy: 4858/6000 (80.9666633605957%)\n", + "Benign Test set: Avg. loss: 1.159281611442566, Accuracy: 3404/6000 (56.73333406448364%)\n", + "Backdoor Test set: Avg. loss: 0.025155487780769665, Accuracy: 5983/6000 (99.7166633605957%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 8\n", + "Distance: cosine, use y 2: [1.4858135222716324, 0.977476786522586]\n", + "Distance: cosine, use x 2: [18.34126270588172, 0.977476786522586]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.95445115010333, 0.0]\n", + "Distance: euclid, use y 2: [0.11622237894780874, 0.07664657906707717]\n", + "Distance: euclid, use x 2: [20.604998755460542, 0.07664657906707717]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 8\n", + "Distance: cosine, use y 2: [0.9388586612372665, 1.714671650236454]\n", + "Distance: cosine, use x 2: [18.10483247218548, 0.9388586612372665]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.20490892995334242, 0.034104232400535395]\n", + "Distance: euclid, use x 2: [20.582603095299284, 0.034104232400535395]\n", + "Distance: euclid, use y 1: [10.954451150103337]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 8\n", + "Distance: cosine, use y 2: [0.22491384864525132, 5.580967454925001]\n", + "Distance: cosine, use x 2: [16.837961963884673, 0.22491384864525132]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.0554968008295349, 0.04022166114053061]\n", + "Distance: euclid, use x 2: [20.6323019044875, 0.04022166114053061]\n", + "Distance: euclid, use y 1: [10.954451150103326]\n", + "Distance: euclid, use x 2: [10.954451150103337, 0.0]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 8\n", + "Distance: cosine, use y 2: [0.26587340013922356, 8.225043933814591]\n", + "Distance: cosine, use x 2: [12.71739294153999, 0.26587340013922345]\n", + "Distance: cosine, use y 1: [10.954451150103356]\n", + "Distance: cosine, use x 2: [10.954451150103317, 0.0]\n", + "Distance: euclid, use y 2: [1.3345927742622061, 4.716279015204606]\n", + "Distance: euclid, use x 2: [14.62993158975982, 1.3345927742622061]\n", + "Distance: euclid, use y 1: [10.9544511501031]\n", + "Distance: euclid, use x 2: [10.954451150103555, 0.0]\n", + "Suspicious Models detected by 3: [2]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "Finished round 9/10\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 9\n", + "Benign Train set: Avg. loss: 0.7385694509808053, Accuracy: 4414/6000 (73.56666326522827%)\n", + "Benign Test set: Avg. loss: 0.9833920300006866, Accuracy: 3851/6000 (64.18333053588867%)\n", + "Backdoor Test set: Avg. loss: 6.464086135228475, Accuracy: 237/6000 (3.9499998092651367%)\n", + "Benign Train set: Avg. loss: 0.7552795909503673, Accuracy: 4377/6000 (72.94999957084656%)\n", + "Benign Test set: Avg. loss: 1.1133271058400471, Accuracy: 3573/6000 (59.54999923706055%)\n", + "Backdoor Test set: Avg. loss: 6.551089922587077, Accuracy: 385/6000 (6.416666507720947%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 9\n", + "Benign Train set: Avg. loss: 0.7495020196792928, Accuracy: 4409/6000 (73.4833300113678%)\n", + "Benign Test set: Avg. loss: 1.0035707255204518, Accuracy: 3794/6000 (63.23333382606506%)\n", + "Backdoor Test set: Avg. loss: 6.403082211812337, Accuracy: 225/6000 (3.749999776482582%)\n", + "Benign Train set: Avg. loss: 0.8397732955661226, Accuracy: 4113/6000 (68.54999661445618%)\n", + "Benign Test set: Avg. loss: 1.2039979100227356, Accuracy: 3399/6000 (56.65000081062317%)\n", + "Backdoor Test set: Avg. loss: 6.2833296457926435, Accuracy: 567/6000 (9.449999779462814%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 9\n", + "Benign Train set: Avg. loss: 0.7583491218850967, Accuracy: 4420/6000 (73.66666793823242%)\n", + "Benign Test set: Avg. loss: 0.9650677740573883, Accuracy: 3915/6000 (65.24999737739563%)\n", + "Backdoor Test set: Avg. loss: 6.44249685605367, Accuracy: 240/6000 (3.999999910593033%)\n", + "Benign Train set: Avg. loss: 0.8237937175213023, Accuracy: 4164/6000 (69.40000057220459%)\n", + "Benign Test set: Avg. loss: 1.1611380378405254, Accuracy: 3523/6000 (58.71666669845581%)\n", + "Backdoor Test set: Avg. loss: 7.784638245900472, Accuracy: 103/6000 (1.7166666686534882%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 9\n", + "Benign Train set: Avg. loss: 3.7090179355854684, Accuracy: 2032/6000 (33.86666774749756%)\n", + "Benign Test set: Avg. loss: 0.9814696907997131, Accuracy: 3822/6000 (63.69999647140503%)\n", + "Backdoor Test set: Avg. loss: 6.440901358922322, Accuracy: 242/6000 (4.03333343565464%)\n", + "Benign Train set: Avg. loss: 0.5484541148105835, Accuracy: 4801/6000 (80.0166666507721%)\n", + "Benign Test set: Avg. loss: 1.2299280961354573, Accuracy: 3238/6000 (53.966665267944336%)\n", + "Backdoor Test set: Avg. loss: 0.010114844112346569, Accuracy: 5998/6000 (99.96666312217712%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 9\n", + "Distance: cosine, use y 2: [1.8038912335374206, 0.5874400448162644]\n", + "Distance: cosine, use x 2: [18.7090114435205, 0.587440044816264]\n", + "Distance: cosine, use y 1: [10.954451150103337]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [5.243774056607994, 0.09955287848724703]\n", + "Distance: euclid, use x 2: [15.228444101490858, 0.09955287848724748]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 9\n", + "Distance: cosine, use y 2: [1.4631308833301224, 0.5653393127897361]\n", + "Distance: cosine, use x 2: [19.00888428867824, 0.5653393127897361]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [6.400049705541952, 0.2211771314311468]\n", + "Distance: euclid, use x 2: [0.22117713143114637, 14.5721913647447]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 9\n", + "Distance: cosine, use y 2: [0.339890961466641, 1.8003777279213722]\n", + "Distance: cosine, use x 2: [19.54458069073736, 0.339890961466641]\n", + "Distance: cosine, use y 1: [10.954451150103337]\n", + "Distance: cosine, use x 2: [10.954451150103326, 0.0]\n", + "Distance: euclid, use y 2: [1.6968000576584066, 0.18619853091320682]\n", + "Distance: euclid, use x 2: [18.781422037614547, 0.18619853091320593]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 9\n", + "Distance: cosine, use y 2: [0.6354703149698402, 6.208864660824278]\n", + "Distance: cosine, use x 2: [15.021103410811712, 0.6354703149698404]\n", + "Distance: cosine, use y 1: [10.95445115010333]\n", + "Distance: cosine, use x 2: [10.95445115010333, 0.0]\n", + "Distance: euclid, use y 2: [2.6538531296435917, 1.3673352415188313]\n", + "Distance: euclid, use x 2: [17.134321602062876, 1.3673352415188313]\n", + "Distance: euclid, use y 1: [10.954451150103349]\n", + "Distance: euclid, use x 2: [10.954451150103319, 0.0]\n", + "Suspicious Models detected by 3: [2]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "\n", + "Calling end\n", + "####################\n", + "All rounds completed successfully\n", + "####################\n", + "This is the end of the flow\n", + "####################\n" + ] + } + ], + "source": [ + "local_runtime = LocalRuntime(aggregator=aggregator_object, collaborators=collaborators)\n", + "\n", + "print(f\"Local runtime collaborators = {local_runtime.collaborators}\")\n", + "\n", + "# change to the internal flow loop\n", + "model = Net()\n", + "model.load_state_dict(pretrained_weights)\n", + "top_model_accuracy = 0\n", + "optimizers = {\n", + " collaborator.name: default_optimizer(model, optimizer_type=args.optimizer_type)\n", + " for collaborator in collaborators\n", + " }\n", + "flflow = FederatedFlow(\n", + " model,\n", + " optimizers,\n", + " device,\n", + " args.comm_round,\n", + " top_model_accuracy,\n", + " NUMBER_OF_MALICIOUS_CLIENTS / TOTAL_CLIENT_NUMBER,\n", + " 'CrowdGuard'\n", + " )\n", + "\n", + "flflow.runtime = local_runtime\n", + "flflow.run()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/openfl-tutorials/experimental/CrowdGuard/PoisoningAttackDemoReduced.ipynb b/openfl-tutorials/experimental/CrowdGuard/PoisoningAttackDemoReduced.ipynb new file mode 100644 index 00000000000..167bfe6fa17 --- /dev/null +++ b/openfl-tutorials/experimental/CrowdGuard/PoisoningAttackDemoReduced.ipynb @@ -0,0 +1,1443 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "4bec0e77", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Aggregator step \"start\" registered\n", + "Collaborator step \"train\" registered\n", + "Aggregator step \"fed_avg_aggregation\" registered\n", + "Aggregator step \"collect_models\" registered\n", + "Collaborator step \"local_validation\" registered\n", + "Aggregator step \"defend\" registered\n", + "Aggregator step \"end\" registered\n" + ] + } + ], + "source": [ + "# Copyright (C) 2022-2024 TU Darmstadt\n", + "# SPDX-License-Identifier: Apache-2.0\n", + "\n", + "# -----------------------------------------------------------\n", + "# Primary author: Phillip Rieger \n", + "# Co-authored-by: Torsten Krauss \n", + "# ------------------------------------------------------------\n", + "\n", + "import argparse\n", + "import os\n", + "import pickle\n", + "import time\n", + "import warnings\n", + "from copy import deepcopy\n", + "from datetime import datetime\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.utils.data import TensorDataset\n", + "import torch.optim as optim\n", + "from torchvision import transforms, datasets\n", + "from sklearn.cluster import AgglomerativeClustering, DBSCAN\n", + "\n", + "from CrowdGuardClientValidation import CrowdGuardClientValidation\n", + "from openfl.experimental.interface import Aggregator, Collaborator, FLSpec\n", + "from openfl.experimental.placement import aggregator, collaborator\n", + "from openfl.experimental.runtime import LocalRuntime\n", + "from cifar10_crowdguard import MEAN, STD_DEV, poison_data, seed_random_generators\n", + "from cifar10_crowdguard import BATCH_SIZE_TRAIN, BATCH_SIZE_TEST, Net, test, default_optimizer\n", + "from cifar10_crowdguard import FederatedFlow\n", + "from cifar10_crowdguard import PRETRAINED_MODEL_FILE, download_pretrained_model\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "94098847", + "metadata": {}, + "outputs": [], + "source": [ + "TOTAL_CLIENT_NUMBER = 4\n", + "PMR = 0.25\n", + "NUMBER_OF_MALICIOUS_CLIENTS = max(1, int(TOTAL_CLIENT_NUMBER * PMR)) if PMR > 0 else 0\n", + "NUMBER_OF_BENIGN_CLIENTS = TOTAL_CLIENT_NUMBER - NUMBER_OF_MALICIOUS_CLIENTS\n", + "NUMBER_OF_ROUNDS = 10" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f0107812", + "metadata": {}, + "outputs": [], + "source": [ + "class CommandLineArgumentSimulator:\n", + " \n", + " def __init__(self):\n", + " self.test_dataset_ratio = 0.4\n", + " self.train_dataset_ratio = 0.4\n", + " self.log_dir = 'test_debug'\n", + " self.comm_round = NUMBER_OF_ROUNDS\n", + " self.flow_internal_loop_test=False\n", + " self.optimizer_type = 'SGD'\n", + " \n", + "args = CommandLineArgumentSimulator()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "fd18c0b3", + "metadata": {}, + "outputs": [], + "source": [ + "download_pretrained_model()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "5d8950eb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Files already downloaded and verified\n", + "Files already downloaded and verified\n" + ] + } + ], + "source": [ + "aggregator_object = Aggregator()\n", + "aggregator_object.private_attributes = {}\n", + "collaborator_names = [f'benign_{i:02d}' for i in range(NUMBER_OF_BENIGN_CLIENTS)] + [f'malicious_{i:02d}' for i in range(NUMBER_OF_MALICIOUS_CLIENTS)] \n", + "collaborators = [Collaborator(name=name) for name in collaborator_names]\n", + "if torch.cuda.is_available():\n", + " device = torch.device(\n", + " \"cuda:1\"\n", + " ) # This will enable Ray library to reserve available GPU(s) for the task\n", + "else:\n", + " device = torch.device(\"cpu\")\n", + "\n", + "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(MEAN, STD_DEV),])\n", + "\n", + "cifar_train = datasets.CIFAR10(root=\"./data\", train=True, download=True, transform=transform)\n", + "cifar_train = [x for x in cifar_train]\n", + "cifar_test = datasets.CIFAR10(root=\"./data\", train=False, download=True, transform=transform)\n", + "cifar_test = [x for x in cifar_test]\n", + "X = torch.stack([x[0] for x in cifar_train] + [x[0] for x in cifar_test])\n", + "Y = torch.LongTensor(np.stack(np.array([x[1] for x in cifar_train] + [x[1] for x in cifar_test])))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "e92f0205", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset info (total 60000): train - 24000, test - 24000, \n" + ] + } + ], + "source": [ + "seed_random_generators(0)\n", + "shuffled_indices = np.arange(X.shape[0])\n", + "np.random.shuffle(shuffled_indices)\n", + "\n", + "N_total_samples = len(cifar_test) + len(cifar_train)\n", + "train_dataset_size = int(N_total_samples * args.train_dataset_ratio)\n", + "test_dataset_size = int(N_total_samples * args.test_dataset_ratio)\n", + "X = X[shuffled_indices]\n", + "Y = Y[shuffled_indices]\n", + "\n", + "train_dataset_data = X[:train_dataset_size]\n", + "train_dataset_targets = Y[:train_dataset_size]\n", + "\n", + "test_dataset_data = X[train_dataset_size:train_dataset_size + test_dataset_size]\n", + "test_dataset_targets = Y[train_dataset_size:train_dataset_size + test_dataset_size]\n", + "print(f\"Dataset info (total {N_total_samples}): train - {test_dataset_targets.shape[0]}, \"\n", + " f\"test - {test_dataset_targets.shape[0]}, \")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d47aca7e", + "metadata": {}, + "outputs": [], + "source": [ + "for idx, collab in enumerate(collaborators):\n", + " # construct the training and test and population dataset\n", + " benign_training_X = train_dataset_data[idx::len(collaborators)]\n", + " benign_training_Y = train_dataset_targets[idx::len(collaborators)]\n", + " \n", + " if 'malicious' in collab.name:\n", + " local_train_data, local_train_targets = poison_data(benign_training_X, benign_training_Y)\n", + " else:\n", + " local_train_data, local_train_targets = benign_training_X, benign_training_Y\n", + " \n", + "\n", + " local_test_data = test_dataset_data[idx::len(collaborators)]\n", + " local_test_targets = test_dataset_targets[idx::len(collaborators)]\n", + " \n", + "\n", + " poison_test_data, poison_test_targets = poison_data(local_test_data, local_test_targets,\n", + " pdr=1.0)\n", + "\n", + " collab.private_attributes = {\n", + " \"train_loader\": torch.utils.data.DataLoader(\n", + " TensorDataset(local_train_data, local_train_targets),\n", + " batch_size=BATCH_SIZE_TRAIN, shuffle=True\n", + " ),\n", + " \"test_loader\": torch.utils.data.DataLoader(\n", + " TensorDataset(local_test_data, local_test_targets),\n", + " batch_size=BATCH_SIZE_TEST, shuffle=False\n", + " ),\n", + " \"backdoor_test_loader\": torch.utils.data.DataLoader(\n", + " TensorDataset(poison_test_data, poison_test_targets),\n", + " batch_size=BATCH_SIZE_TEST, shuffle=False\n", + " ),\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "16c46575", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Benign Train set: Avg. loss: 3.3837500759895813, Accuracy: 2083/6000 (34.717%)\n", + "Benign Test set: Avg. loss: 0.9973345994949341, Accuracy: 3768/6000 (62.800%)\n", + "Backdoor Test set: Avg. loss: 5.72957197825114, Accuracy: 325/6000 (5.417%)\n" + ] + }, + { + "data": { + "text/plain": [ + "0.05416666716337204" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pretrained_weights = torch.load(PRETRAINED_MODEL_FILE, map_location=device)\n", + "test_model = Net().to(device)\n", + "test_model.load_state_dict(pretrained_weights)\n", + "test(test_model, collab.private_attributes['train_loader'], device, test_train='Train')\n", + "test(test_model, collab.private_attributes['test_loader'], device)\n", + "test(test_model, collab.private_attributes['backdoor_test_loader'], device, mode='Backdoor')" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "5e9721c1", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Local runtime collaborators = ['benign_00', 'benign_01', 'benign_02', 'malicious_00']\n", + "####################\n", + "Round 0...\n", + "####################\n", + "\n", + "Calling start\n", + "Performing initialization for model\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 0\n", + "Benign Train set: Avg. loss: 0.9885769543495584, Accuracy: 3790/6000 (63.16666603088379%)\n", + "Benign Test set: Avg. loss: 1.001497248808543, Accuracy: 3761/6000 (62.68333196640015%)\n", + "Backdoor Test set: Avg. loss: 5.7991689046223955, Accuracy: 330/6000 (5.499999970197678%)\n", + "Benign Train set: Avg. loss: 1.0561292523399313, Accuracy: 3597/6000 (59.950000047683716%)\n", + "Benign Test set: Avg. loss: 1.197381854057312, Accuracy: 3327/6000 (55.44999837875366%)\n", + "Backdoor Test set: Avg. loss: 5.618184248606364, Accuracy: 302/6000 (5.0333332270383835%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 0\n", + "Benign Train set: Avg. loss: 1.0144326242994754, Accuracy: 3744/6000 (62.40000128746033%)\n", + "Benign Test set: Avg. loss: 1.0136706431706746, Accuracy: 3713/6000 (61.88333034515381%)\n", + "Backdoor Test set: Avg. loss: 5.686683019002278, Accuracy: 315/6000 (5.249999836087227%)\n", + "Benign Train set: Avg. loss: 1.0219575502771012, Accuracy: 3705/6000 (61.75000071525574%)\n", + "Benign Test set: Avg. loss: 1.168825924396515, Accuracy: 3347/6000 (55.78333139419556%)\n", + "Backdoor Test set: Avg. loss: 4.57614795366923, Accuracy: 514/6000 (8.566666394472122%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 0\n", + "Benign Train set: Avg. loss: 0.9901170968375308, Accuracy: 3809/6000 (63.483333587646484%)\n", + "Benign Test set: Avg. loss: 0.9752019345760345, Accuracy: 3854/6000 (64.23333287239075%)\n", + "Backdoor Test set: Avg. loss: 5.79969318707784, Accuracy: 314/6000 (5.233333259820938%)\n", + "Benign Train set: Avg. loss: 0.9862309634051425, Accuracy: 3803/6000 (63.38333487510681%)\n", + "Benign Test set: Avg. loss: 1.1106985807418823, Accuracy: 3549/6000 (59.14999842643738%)\n", + "Backdoor Test set: Avg. loss: 6.382502635320027, Accuracy: 96/6000 (1.5999998897314072%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 0\n", + "Benign Train set: Avg. loss: 3.3836448154550918, Accuracy: 2083/6000 (34.716665744781494%)\n", + "Benign Test set: Avg. loss: 0.9973345994949341, Accuracy: 3768/6000 (62.800002098083496%)\n", + "Backdoor Test set: Avg. loss: 5.72957197825114, Accuracy: 325/6000 (5.416666716337204%)\n", + "Benign Train set: Avg. loss: 0.5883342332820943, Accuracy: 4751/6000 (79.18333411216736%)\n", + "Benign Test set: Avg. loss: 1.2542604207992554, Accuracy: 3293/6000 (54.883331060409546%)\n", + "Backdoor Test set: Avg. loss: 0.03580721157292525, Accuracy: 5985/6000 (99.75000023841858%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 0\n", + "Distance: cosine, use y 2: [3.139687311456905, 0.24390152941851007]\n", + "Distance: cosine, use x 2: [18.736333963895376, 0.2439015294185105]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.037936255947664144, 0.08153858333130248]\n", + "Distance: euclid, use x 2: [20.61820415899534, 0.037936255947664144]\n", + "Distance: euclid, use y 1: [10.954451150103338]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 0\n", + "Distance: cosine, use y 2: [2.52111887454884, 0.3429832093794909]\n", + "Distance: cosine, use x 2: [19.02273590470069, 0.34298320937949045]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.05887220497062895, 0.4755074262107195]\n", + "Distance: euclid, use x 2: [20.471009787948905, 0.05887220497062895]\n", + "Distance: euclid, use y 1: [10.954451150103338]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 0\n", + "Distance: cosine, use y 2: [1.5919269464334604, 0.696426347466276]\n", + "Distance: cosine, use x 2: [19.253468984032075, 0.696426347466276]\n", + "Distance: cosine, use y 1: [10.954451150103337]\n", + "Distance: cosine, use x 2: [10.954451150103331, 0.0]\n", + "Distance: euclid, use y 2: [0.0473784673111064, 0.6217562037526347]\n", + "Distance: euclid, use x 2: [20.414253777149813, 0.04737846731110551]\n", + "Distance: euclid, use y 1: [10.954451150103338]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 0\n", + "Distance: cosine, use y 2: [1.586522424378272, 3.622045633352822]\n", + "Distance: cosine, use x 2: [15.88878766536782, 1.5865224243782723]\n", + "Distance: cosine, use y 1: [10.954451150103303]\n", + "Distance: cosine, use x 2: [10.954451150103333, 0.0]\n", + "Distance: euclid, use y 2: [1.3087604961220998, 2.2733504384252856]\n", + "Distance: euclid, use x 2: [16.46799567492959, 1.3087604961220998]\n", + "Distance: euclid, use y 1: [10.954451150103305]\n", + "Distance: euclid, use x 2: [10.954451150103354, 0.0]\n", + "Suspicious Models detected by 3: [0, 2]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "Finished round 1/10\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 1\n", + "Benign Train set: Avg. loss: 0.9778860347702149, Accuracy: 3858/6000 (64.30000066757202%)\n", + "Benign Test set: Avg. loss: 1.0275444785753887, Accuracy: 3708/6000 (61.799997091293335%)\n", + "Backdoor Test set: Avg. loss: 5.126115083694458, Accuracy: 369/6000 (6.149999797344208%)\n", + "Benign Train set: Avg. loss: 1.055129660887921, Accuracy: 3531/6000 (58.85000228881836%)\n", + "Benign Test set: Avg. loss: 1.2361122767130535, Accuracy: 3223/6000 (53.716665506362915%)\n", + "Backdoor Test set: Avg. loss: 6.004116376241048, Accuracy: 266/6000 (4.43333312869072%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 1\n", + "Benign Train set: Avg. loss: 0.9974439695794531, Accuracy: 3841/6000 (64.01666402816772%)\n", + "Benign Test set: Avg. loss: 1.050683617591858, Accuracy: 3690/6000 (61.500000953674316%)\n", + "Backdoor Test set: Avg. loss: 5.031114816665649, Accuracy: 361/6000 (6.016666442155838%)\n", + "Benign Train set: Avg. loss: 1.0382162418137206, Accuracy: 3697/6000 (61.61666512489319%)\n", + "Benign Test set: Avg. loss: 1.172116796175639, Accuracy: 3390/6000 (56.49999976158142%)\n", + "Backdoor Test set: Avg. loss: 5.203440030415853, Accuracy: 127/6000 (2.1166667342185974%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 1\n", + "Benign Train set: Avg. loss: 0.9782646976886912, Accuracy: 3908/6000 (65.13333320617676%)\n", + "Benign Test set: Avg. loss: 1.0082112451394398, Accuracy: 3819/6000 (63.65000009536743%)\n", + "Backdoor Test set: Avg. loss: 5.111261924107869, Accuracy: 341/6000 (5.6833334267139435%)\n", + "Benign Train set: Avg. loss: 1.0318745732941526, Accuracy: 3704/6000 (61.73333525657654%)\n", + "Benign Test set: Avg. loss: 1.1651872595151265, Accuracy: 3457/6000 (57.616668939590454%)\n", + "Backdoor Test set: Avg. loss: 6.8470542430877686, Accuracy: 20/6000 (0.33333334140479565%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 1\n", + "Benign Train set: Avg. loss: 3.0500604871739734, Accuracy: 2085/6000 (34.74999964237213%)\n", + "Benign Test set: Avg. loss: 1.0224250356356304, Accuracy: 3771/6000 (62.849998474121094%)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Backdoor Test set: Avg. loss: 5.079943021138509, Accuracy: 371/6000 (6.183333322405815%)\n", + "Benign Train set: Avg. loss: 0.5856330990791321, Accuracy: 4701/6000 (78.35000157356262%)\n", + "Benign Test set: Avg. loss: 1.2663490772247314, Accuracy: 3113/6000 (51.883333921432495%)\n", + "Backdoor Test set: Avg. loss: 0.023603687683741253, Accuracy: 5993/6000 (99.88332986831665%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 1\n", + "Distance: cosine, use y 2: [4.425570297924237, 0.3427598443292488]\n", + "Distance: cosine, use x 2: [16.344858266211432, 0.34275984432924833]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.19457190016820558, 0.019510296358126844]\n", + "Distance: euclid, use x 2: [20.585892915372213, 0.019510296358126844]\n", + "Distance: euclid, use y 1: [10.95445115010333]\n", + "Distance: euclid, use x 2: [10.954451150103331, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 1\n", + "Distance: cosine, use y 2: [2.801438809103918, 0.9144500874639174]\n", + "Distance: cosine, use x 2: [17.0218502598279, 0.914450087463917]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.15236963958841976, 0.008001295032160627]\n", + "Distance: euclid, use x 2: [20.60343664068541, 0.008001295032160627]\n", + "Distance: euclid, use y 1: [10.954451150103315]\n", + "Distance: euclid, use x 2: [10.954451150103319, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 1\n", + "Distance: cosine, use y 2: [0.8950936094280646, 3.8962927324545014]\n", + "Distance: cosine, use x 2: [14.828649175189087, 0.895093609428065]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.13459560982305163, 0.04044234360506671]\n", + "Distance: euclid, use x 2: [20.602988399691313, 0.04044234360506582]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.95445115010333, 0.0]\n", + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 1\n", + "Distance: cosine, use y 2: [0.10531291572774401, 9.719286605534212]\n", + "Distance: cosine, use x 2: [10.18847609917082, 0.10531291572774402]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.5068797928668665, 9.141204145247993]\n", + "Distance: euclid, use x 2: [12.191430763432086, 0.5068797928668665]\n", + "Distance: euclid, use y 1: [10.954451150103322]\n", + "Distance: euclid, use x 2: [10.954451150103337, 0.0]\n", + "Suspicious Models detected by 3: [0]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "Finished round 2/10\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 2\n", + "Benign Train set: Avg. loss: 0.9301359200097145, Accuracy: 3988/6000 (66.46666526794434%)\n", + "Benign Test set: Avg. loss: 1.0193569262822468, Accuracy: 3759/6000 (62.65000104904175%)\n", + "Backdoor Test set: Avg. loss: 5.6510575612386065, Accuracy: 197/6000 (3.283333405852318%)\n", + "Benign Train set: Avg. loss: 0.8698681494657029, Accuracy: 4072/6000 (67.86666512489319%)\n", + "Benign Test set: Avg. loss: 1.0671072800954182, Accuracy: 3640/6000 (60.66666841506958%)\n", + "Backdoor Test set: Avg. loss: 6.187037706375122, Accuracy: 171/6000 (2.850000001490116%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 2\n", + "Benign Train set: Avg. loss: 0.9605293927040506, Accuracy: 3932/6000 (65.53333401679993%)\n", + "Benign Test set: Avg. loss: 1.0264412760734558, Accuracy: 3763/6000 (62.71666884422302%)\n", + "Backdoor Test set: Avg. loss: 5.560268004735311, Accuracy: 186/6000 (3.099999949336052%)\n", + "Benign Train set: Avg. loss: 0.9960020336698978, Accuracy: 3756/6000 (62.59999871253967%)\n", + "Benign Test set: Avg. loss: 1.1556094487508137, Accuracy: 3446/6000 (57.43333101272583%)\n", + "Backdoor Test set: Avg. loss: 5.153351465861003, Accuracy: 652/6000 (10.866666585206985%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 2\n", + "Benign Train set: Avg. loss: 0.9454977268234213, Accuracy: 3980/6000 (66.33332967758179%)\n", + "Benign Test set: Avg. loss: 1.0062820414702098, Accuracy: 3797/6000 (63.28333020210266%)\n", + "Backdoor Test set: Avg. loss: 5.6208906173706055, Accuracy: 200/6000 (3.333333134651184%)\n", + "Benign Train set: Avg. loss: 0.9145594096564232, Accuracy: 4006/6000 (66.76666736602783%)\n", + "Benign Test set: Avg. loss: 1.064531107743581, Accuracy: 3645/6000 (60.750001668930054%)\n", + "Backdoor Test set: Avg. loss: 6.3380054632822675, Accuracy: 148/6000 (2.4666666984558105%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 2\n", + "Benign Train set: Avg. loss: 3.3212531792356614, Accuracy: 2006/6000 (33.43333303928375%)\n", + "Benign Test set: Avg. loss: 1.0044313569863637, Accuracy: 3837/6000 (63.95000219345093%)\n", + "Backdoor Test set: Avg. loss: 5.585842609405518, Accuracy: 199/6000 (3.3166665583848953%)\n", + "Benign Train set: Avg. loss: 0.5776876041546781, Accuracy: 4669/6000 (77.8166651725769%)\n", + "Benign Test set: Avg. loss: 1.275720755259196, Accuracy: 3109/6000 (51.81666612625122%)\n", + "Backdoor Test set: Avg. loss: 0.011549354065209627, Accuracy: 5986/6000 (99.76666569709778%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 2\n", + "Distance: cosine, use y 2: [2.090019121577491, 0.5250757335236536]\n", + "Distance: cosine, use x 2: [19.14903846235277, 0.525075733523654]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.95445115010333, 0.0]\n", + "Distance: euclid, use y 2: [0.05904294188861314, 0.028922285540788906]\n", + "Distance: euclid, use x 2: [20.63419646946428, 0.028922285540788906]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 2\n", + "Distance: cosine, use y 2: [0.2429033069883788, 1.2457802166893739]\n", + "Distance: cosine, use x 2: [20.09467701539262, 0.24290330698837792]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.0344511006636834, 0.046264541991975605]\n", + "Distance: euclid, use x 2: [20.636999351002657, 0.0344511006636834]\n", + "Distance: euclid, use y 1: [10.954451150103335]\n", + "Distance: euclid, use x 2: [10.95445115010333, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 2\n", + "Distance: cosine, use y 2: [0.46303665011333894, 1.5631590460968328]\n", + "Distance: cosine, use x 2: [19.683905163520063, 0.46303665011333983]\n", + "Distance: cosine, use y 1: [10.954451150103312]\n", + "Distance: cosine, use x 2: [10.954451150103319, 0.0]\n", + "Distance: euclid, use y 2: [0.029072666051598084, 0.03096571611546306]\n", + "Distance: euclid, use x 2: [20.64375620079601, 0.029072666051598084]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 2\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Distance: cosine, use y 2: [8.731105394391083, 0.12413702567453089]\n", + "Distance: cosine, use x 2: [13.670805247683036, 0.12413702567453089]\n", + "Distance: cosine, use y 1: [10.954451150103324]\n", + "Distance: cosine, use x 2: [10.954451150103342, 0.0]\n", + "Distance: euclid, use y 2: [2.465981367017415, 1.0990384519078722]\n", + "Distance: euclid, use x 2: [17.190112051993644, 1.0990384519078722]\n", + "Distance: euclid, use y 1: [10.954451150103347]\n", + "Distance: euclid, use x 2: [10.954451150103319, 0.0]\n", + "Suspicious Models detected by 3: [1]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "Finished round 3/10\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 3\n", + "Benign Train set: Avg. loss: 0.8797781603767517, Accuracy: 4098/6000 (68.29999685287476%)\n", + "Benign Test set: Avg. loss: 0.988082766532898, Accuracy: 3858/6000 (64.30000066757202%)\n", + "Backdoor Test set: Avg. loss: 5.625429153442383, Accuracy: 357/6000 (5.949999764561653%)\n", + "Benign Train set: Avg. loss: 0.8560976173649443, Accuracy: 4138/6000 (68.96666884422302%)\n", + "Benign Test set: Avg. loss: 1.0846137603123982, Accuracy: 3643/6000 (60.71666479110718%)\n", + "Backdoor Test set: Avg. loss: 5.940346876780192, Accuracy: 266/6000 (4.43333312869072%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 3\n", + "Benign Train set: Avg. loss: 0.9047486468198451, Accuracy: 4064/6000 (67.73333549499512%)\n", + "Benign Test set: Avg. loss: 0.997472216685613, Accuracy: 3839/6000 (63.983333110809326%)\n", + "Backdoor Test set: Avg. loss: 5.525551478068034, Accuracy: 349/6000 (5.816666781902313%)\n", + "Benign Train set: Avg. loss: 0.912026990601357, Accuracy: 3910/6000 (65.16666412353516%)\n", + "Benign Test set: Avg. loss: 1.1148398319880168, Accuracy: 3486/6000 (58.09999704360962%)\n", + "Backdoor Test set: Avg. loss: 5.849345684051514, Accuracy: 313/6000 (5.216666683554649%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 3\n", + "Benign Train set: Avg. loss: 0.8902336378046807, Accuracy: 4134/6000 (68.90000104904175%)\n", + "Benign Test set: Avg. loss: 0.9691097537676493, Accuracy: 3901/6000 (65.01666307449341%)\n", + "Backdoor Test set: Avg. loss: 5.609223286310832, Accuracy: 354/6000 (5.900000035762787%)\n", + "Benign Train set: Avg. loss: 0.9305718002167154, Accuracy: 3965/6000 (66.08332991600037%)\n", + "Benign Test set: Avg. loss: 1.1298276980717976, Accuracy: 3553/6000 (59.21666622161865%)\n", + "Backdoor Test set: Avg. loss: 5.478763898213704, Accuracy: 385/6000 (6.416666507720947%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 3\n", + "Benign Train set: Avg. loss: 3.2841210809159787, Accuracy: 2122/6000 (35.366666316986084%)\n", + "Benign Test set: Avg. loss: 0.9808564384778341, Accuracy: 3840/6000 (63.999998569488525%)\n", + "Backdoor Test set: Avg. loss: 5.563961982727051, Accuracy: 359/6000 (5.9833332896232605%)\n", + "Benign Train set: Avg. loss: 0.584124024085542, Accuracy: 4729/6000 (78.81666421890259%)\n", + "Benign Test set: Avg. loss: 1.2889393369356792, Accuracy: 3182/6000 (53.03333401679993%)\n", + "Backdoor Test set: Avg. loss: 0.01501223910599947, Accuracy: 5996/6000 (99.93333220481873%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 3\n", + "Distance: cosine, use y 2: [0.6897721817120037, 2.2868132657080324]\n", + "Distance: cosine, use x 2: [18.038688763229835, 0.6897721817120037]\n", + "Distance: cosine, use y 1: [10.954451150103337]\n", + "Distance: cosine, use x 2: [10.95445115010333, 0.0]\n", + "Distance: euclid, use y 2: [0.07893401250770271, 0.13630794113106504]\n", + "Distance: euclid, use x 2: [20.52155186630494, 0.0789340125077036]\n", + "Distance: euclid, use y 1: [10.954451150103313]\n", + "Distance: euclid, use x 2: [10.954451150103322, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 3\n", + "Distance: cosine, use y 2: [0.5250672293876582, 2.7721144925007435]\n", + "Distance: cosine, use x 2: [18.1043821779655, 0.5250672293876577]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.008386279013575582, 0.4986609520378167]\n", + "Distance: euclid, use x 2: [20.21018883810196, 0.008386279013574693]\n", + "Distance: euclid, use y 1: [10.954451150103338]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 3\n", + "Distance: cosine, use y 2: [0.9376233056986072, 0.5526386053609924]\n", + "Distance: cosine, use x 2: [19.144297467758847, 0.5526386053609933]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.07042758809094973, 0.41313934452415424]\n", + "Distance: euclid, use x 2: [20.470667882786586, 0.07042758809094973]\n", + "Distance: euclid, use y 1: [10.954451150103338]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 3\n", + "Distance: cosine, use y 2: [0.7801455877138592, 6.476302880460576]\n", + "Distance: cosine, use x 2: [13.233120749245378, 0.7801455877138593]\n", + "Distance: cosine, use y 1: [10.954451150103345]\n", + "Distance: cosine, use x 2: [10.954451150103326, 0.0]\n", + "Distance: euclid, use y 2: [2.8989432626992415, 1.4209623673912621]\n", + "Distance: euclid, use x 2: [15.597565528663878, 1.4209623673912624]\n", + "Distance: euclid, use y 1: [10.954451150103294]\n", + "Distance: euclid, use x 2: [10.954451150103367, 0.0]\n", + "Suspicious Models detected by 3: [0, 2]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "Finished round 4/10\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 4\n", + "Benign Train set: Avg. loss: 0.8473593151949822, Accuracy: 4201/6000 (70.0166642665863%)\n", + "Benign Test set: Avg. loss: 0.9812094370524088, Accuracy: 3853/6000 (64.21666741371155%)\n", + "Backdoor Test set: Avg. loss: 5.526396592458089, Accuracy: 362/6000 (6.033333390951157%)\n", + "Benign Train set: Avg. loss: 0.8603191133192245, Accuracy: 4075/6000 (67.91666746139526%)\n", + "Benign Test set: Avg. loss: 1.12581870953242, Accuracy: 3511/6000 (58.51666331291199%)\n", + "Backdoor Test set: Avg. loss: 6.302475372950236, Accuracy: 324/6000 (5.399999767541885%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 4\n", + "Benign Train set: Avg. loss: 0.8707632777538705, Accuracy: 4182/6000 (69.69999670982361%)\n", + "Benign Test set: Avg. loss: 0.9912213583787283, Accuracy: 3871/6000 (64.51666355133057%)\n", + "Backdoor Test set: Avg. loss: 5.441365718841553, Accuracy: 378/6000 (6.300000101327896%)\n", + "Benign Train set: Avg. loss: 0.929772324384527, Accuracy: 3987/6000 (66.44999980926514%)\n", + "Benign Test set: Avg. loss: 1.139621078968048, Accuracy: 3508/6000 (58.46666693687439%)\n", + "Backdoor Test set: Avg. loss: 5.0033665498097735, Accuracy: 499/6000 (8.316666632890701%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 4\n", + "Benign Train set: Avg. loss: 0.8620590253713283, Accuracy: 4236/6000 (70.59999704360962%)\n", + "Benign Test set: Avg. loss: 0.9697187145551046, Accuracy: 3954/6000 (65.89999794960022%)\n", + "Backdoor Test set: Avg. loss: 5.5111550490061445, Accuracy: 367/6000 (6.11666664481163%)\n", + "Benign Train set: Avg. loss: 0.8872369096634236, Accuracy: 4138/6000 (68.96666884422302%)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Benign Test set: Avg. loss: 1.0738922754923503, Accuracy: 3626/6000 (60.43333411216736%)\n", + "Backdoor Test set: Avg. loss: 5.036426067352295, Accuracy: 421/6000 (7.016666233539581%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 4\n", + "Benign Train set: Avg. loss: 3.241051955425993, Accuracy: 2149/6000 (35.81666648387909%)\n", + "Benign Test set: Avg. loss: 0.9790957868099213, Accuracy: 3860/6000 (64.33333158493042%)\n", + "Backdoor Test set: Avg. loss: 5.473701159159343, Accuracy: 386/6000 (6.433333456516266%)\n", + "Benign Train set: Avg. loss: 0.5482086730288699, Accuracy: 4775/6000 (79.58333492279053%)\n", + "Benign Test set: Avg. loss: 1.1979321042696636, Accuracy: 3244/6000 (54.06666398048401%)\n", + "Backdoor Test set: Avg. loss: 0.012110005132853985, Accuracy: 5992/6000 (99.86666440963745%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 4\n", + "Distance: cosine, use y 2: [1.6868625176710177, 1.0355540140924653]\n", + "Distance: cosine, use x 2: [18.332896012058793, 1.0355540140924657]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.015551310552236686, 0.20436053166930535]\n", + "Distance: euclid, use x 2: [20.464638642403706, 0.015551310552236686]\n", + "Distance: euclid, use y 1: [10.954451150103313]\n", + "Distance: euclid, use x 2: [10.954451150103315, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 4\n", + "Distance: cosine, use y 2: [1.9480474991983874, 0.9100573433205303]\n", + "Distance: cosine, use x 2: [17.34332939260181, 0.9100573433205308]\n", + "Distance: cosine, use y 1: [10.954451150103345]\n", + "Distance: cosine, use x 2: [10.954451150103326, 0.0]\n", + "Distance: euclid, use y 2: [0.09394431381224333, 0.4889495857715147]\n", + "Distance: euclid, use x 2: [20.464938500823514, 0.09394431381224333]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 4\n", + "Distance: cosine, use y 2: [1.3971158837628077, 0.8438158134977676]\n", + "Distance: cosine, use x 2: [17.321712581324253, 0.8438158134977676]\n", + "Distance: cosine, use y 1: [10.954451150103315]\n", + "Distance: cosine, use x 2: [10.954451150103322, 0.0]\n", + "Distance: euclid, use y 2: [0.20411737313829725, 0.38046902129413773]\n", + "Distance: euclid, use x 2: [20.171324137508435, 0.20411737313829725]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 4\n", + "Distance: cosine, use y 2: [9.492243684847173, 0.3191188707738067]\n", + "Distance: cosine, use x 2: [11.253934367112079, 0.3191188707738067]\n", + "Distance: cosine, use y 1: [10.954451150103333]\n", + "Distance: cosine, use x 2: [10.95445115010333, 0.0]\n", + "Distance: euclid, use y 2: [2.1653816532140224, 0.9962351020896669]\n", + "Distance: euclid, use x 2: [17.690340045161005, 0.9962351020896669]\n", + "Distance: euclid, use y 1: [10.954451150103129]\n", + "Distance: euclid, use x 2: [10.954451150103536, 0.0]\n", + "Suspicious Models detected by 3: [0, 2]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "Finished round 5/10\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 5\n", + "Benign Train set: Avg. loss: 0.8537916736082828, Accuracy: 4182/6000 (69.69999670982361%)\n", + "Benign Test set: Avg. loss: 1.0111433267593384, Accuracy: 3784/6000 (63.066667318344116%)\n", + "Backdoor Test set: Avg. loss: 5.398749272028605, Accuracy: 420/6000 (7.000000029802322%)\n", + "Benign Train set: Avg. loss: 0.874724712460599, Accuracy: 4125/6000 (68.75%)\n", + "Benign Test set: Avg. loss: 1.1135798692703247, Accuracy: 3577/6000 (59.61666703224182%)\n", + "Backdoor Test set: Avg. loss: 5.225924889246623, Accuracy: 441/6000 (7.3499999940395355%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 5\n", + "Benign Train set: Avg. loss: 0.8769010108836154, Accuracy: 4142/6000 (69.03333067893982%)\n", + "Benign Test set: Avg. loss: 1.0202871362368267, Accuracy: 3776/6000 (62.93333172798157%)\n", + "Backdoor Test set: Avg. loss: 5.315138339996338, Accuracy: 422/6000 (7.0333331823349%)\n", + "Benign Train set: Avg. loss: 0.8616926990290905, Accuracy: 4124/6000 (68.7333345413208%)\n", + "Benign Test set: Avg. loss: 1.095106561978658, Accuracy: 3566/6000 (59.433335065841675%)\n", + "Backdoor Test set: Avg. loss: 5.312667687733968, Accuracy: 452/6000 (7.533333450555801%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 5\n", + "Benign Train set: Avg. loss: 0.873329536077824, Accuracy: 4199/6000 (69.9833333492279%)\n", + "Benign Test set: Avg. loss: 0.9910587867101034, Accuracy: 3871/6000 (64.51666355133057%)\n", + "Backdoor Test set: Avg. loss: 5.387445529301961, Accuracy: 407/6000 (6.783333420753479%)\n", + "Benign Train set: Avg. loss: 0.9390684058691593, Accuracy: 3925/6000 (65.41666388511658%)\n", + "Benign Test set: Avg. loss: 1.1703458031018574, Accuracy: 3487/6000 (58.116668462753296%)\n", + "Backdoor Test set: Avg. loss: 6.923852841059367, Accuracy: 156/6000 (2.6000000536441803%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 5\n", + "Benign Train set: Avg. loss: 3.1833345179862165, Accuracy: 2108/6000 (35.13333201408386%)\n", + "Benign Test set: Avg. loss: 1.0030378798643749, Accuracy: 3794/6000 (63.23333382606506%)\n", + "Backdoor Test set: Avg. loss: 5.3597166538238525, Accuracy: 426/6000 (7.100000232458115%)\n", + "Benign Train set: Avg. loss: 0.5394560545682907, Accuracy: 4862/6000 (81.03333115577698%)\n", + "Benign Test set: Avg. loss: 1.1633540789286296, Accuracy: 3418/6000 (56.966668367385864%)\n", + "Backdoor Test set: Avg. loss: 0.03771169111132622, Accuracy: 5942/6000 (99.03333187103271%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 5\n", + "Distance: cosine, use y 2: [0.32204569057429655, 2.0679682501473744]\n", + "Distance: cosine, use x 2: [18.673222172064634, 0.32204569057429566]\n", + "Distance: cosine, use y 1: [10.954451150103342]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.7952719011254965, 0.062112622190042543]\n", + "Distance: euclid, use x 2: [20.22204485928068, 0.062112622190042543]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 5\n", + "Distance: cosine, use y 2: [2.0462303005952664, 1.179787399308342]\n", + "Distance: cosine, use x 2: [17.29438610427769, 1.1797873993083416]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.5940232587587717, 0.0913571067874166]\n", + "Distance: euclid, use x 2: [20.41687447361079, 0.09135710678741749]\n", + "Distance: euclid, use y 1: [10.954451150103313]\n", + "Distance: euclid, use x 2: [10.954451150103315, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 5\n", + "Distance: cosine, use y 2: [2.452172980470925, 1.0277639087777235]\n", + "Distance: cosine, use x 2: [17.601670629695057, 1.0277639087777235]\n", + "Distance: cosine, use y 1: [10.954451150103289]\n", + "Distance: cosine, use x 2: [10.954451150103381, 0.0]\n", + "Distance: euclid, use y 2: [0.03135712597971185, 0.31858317297936445]\n", + "Distance: euclid, use x 2: [20.498127626525033, 0.031357125979710965]\n", + "Distance: euclid, use y 1: [10.954451150103337]\n", + "Distance: euclid, use x 2: [10.954451150103331, 0.0]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 5\n", + "Distance: cosine, use y 2: [2.2935255570229356, 3.2583348201762328]\n", + "Distance: cosine, use x 2: [14.272313280144168, 2.2935255570229356]\n", + "Distance: cosine, use y 1: [10.954451150103312]\n", + "Distance: cosine, use x 2: [10.954451150103326, 0.0]\n", + "Distance: euclid, use y 2: [3.7317955424528026, 0.441976840267575]\n", + "Distance: euclid, use x 2: [17.495995482338625, 0.441976840267575]\n", + "Distance: euclid, use y 1: [10.954451150103061]\n", + "Distance: euclid, use x 2: [10.954451150103594, 0.0]\n", + "Suspicious Models detected by 3: [0, 2]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "Finished round 6/10\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 6\n", + "Benign Train set: Avg. loss: 0.8286352633161748, Accuracy: 4201/6000 (70.0166642665863%)\n", + "Benign Test set: Avg. loss: 0.9954776068528494, Accuracy: 3803/6000 (63.38333487510681%)\n", + "Backdoor Test set: Avg. loss: 5.628017028172811, Accuracy: 370/6000 (6.166666746139526%)\n", + "Benign Train set: Avg. loss: 0.8233681094456227, Accuracy: 4156/6000 (69.26666498184204%)\n", + "Benign Test set: Avg. loss: 1.1489962935447693, Accuracy: 3514/6000 (58.56666564941406%)\n", + "Backdoor Test set: Avg. loss: 7.499392986297607, Accuracy: 150/6000 (2.500000037252903%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 6\n", + "Benign Train set: Avg. loss: 0.8520627500528984, Accuracy: 4189/6000 (69.81666684150696%)\n", + "Benign Test set: Avg. loss: 1.0100693106651306, Accuracy: 3760/6000 (62.66666650772095%)\n", + "Backdoor Test set: Avg. loss: 5.5396309693654375, Accuracy: 362/6000 (6.033333390951157%)\n", + "Benign Train set: Avg. loss: 0.8405023550099515, Accuracy: 4123/6000 (68.71666312217712%)\n", + "Benign Test set: Avg. loss: 1.116740624109904, Accuracy: 3513/6000 (58.55000019073486%)\n", + "Backdoor Test set: Avg. loss: 4.886253118515015, Accuracy: 796/6000 (13.266666233539581%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 6\n", + "Benign Train set: Avg. loss: 0.8374395554370069, Accuracy: 4245/6000 (70.74999809265137%)\n", + "Benign Test set: Avg. loss: 0.983585943778356, Accuracy: 3876/6000 (64.59999680519104%)\n", + "Backdoor Test set: Avg. loss: 5.612242301305135, Accuracy: 362/6000 (6.033333390951157%)\n", + "Benign Train set: Avg. loss: 0.8825467029150497, Accuracy: 3966/6000 (66.10000133514404%)\n", + "Benign Test set: Avg. loss: 1.172180672486623, Accuracy: 3436/6000 (57.26666450500488%)\n", + "Backdoor Test set: Avg. loss: 8.384706020355225, Accuracy: 75/6000 (1.2500000186264515%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 6\n", + "Benign Train set: Avg. loss: 3.2969402074813843, Accuracy: 2109/6000 (35.1500004529953%)\n", + "Benign Test set: Avg. loss: 0.9908452332019806, Accuracy: 3849/6000 (64.14999961853027%)\n", + "Backdoor Test set: Avg. loss: 5.578811009724935, Accuracy: 379/6000 (6.316666305065155%)\n", + "Benign Train set: Avg. loss: 0.6076196164209792, Accuracy: 4612/6000 (76.8666684627533%)\n", + "Benign Test set: Avg. loss: 1.2845233082771301, Accuracy: 3085/6000 (51.41666531562805%)\n", + "Backdoor Test set: Avg. loss: 0.014343314183255037, Accuracy: 5998/6000 (99.96666312217712%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 6\n", + "Distance: cosine, use y 2: [3.1959195824794913, 0.12506571720480553]\n", + "Distance: cosine, use x 2: [18.66126607580756, 0.12506571720480597]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.11455014060527713, 0.04188976487882634]\n", + "Distance: euclid, use x 2: [20.493965183272902, 0.04188976487882545]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 6\n", + "Distance: cosine, use y 2: [0.15689968394637965, 0.6351126526557955]\n", + "Distance: cosine, use x 2: [20.02931110591679, 0.15689968394637965]\n", + "Distance: cosine, use y 1: [10.954451150103306]\n", + "Distance: cosine, use x 2: [10.954451150103333, 0.0]\n", + "Distance: euclid, use y 2: [0.07054778219426705, 0.08106611839147426]\n", + "Distance: euclid, use x 2: [20.60137755938585, 0.07054778219426705]\n", + "Distance: euclid, use y 1: [10.954451150103335]\n", + "Distance: euclid, use x 2: [10.95445115010333, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 6\n", + "Distance: cosine, use y 2: [0.11406124941259943, 2.9178387866028697]\n", + "Distance: cosine, use x 2: [18.414472367210298, 0.11406124941259987]\n", + "Distance: cosine, use y 1: [10.954451150103317]\n", + "Distance: cosine, use x 2: [10.95445115010332, 0.0]\n", + "Distance: euclid, use y 2: [0.004118947111889426, 0.3445463512501741]\n", + "Distance: euclid, use x 2: [20.33123125017665, 0.004118947111888538]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.95445115010333, 0.0]\n", + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 6\n", + "Distance: cosine, use y 2: [1.004308933278638, 2.8780471952495597]\n", + "Distance: cosine, use x 2: [15.634467194987792, 1.004308933278638]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.7226242922769606, 4.994200594550494]\n", + "Distance: euclid, use x 2: [15.979720687217252, 0.722624292276961]\n", + "Distance: euclid, use y 1: [10.954451150103328]\n", + "Distance: euclid, use x 2: [10.954451150103335, 0.0]\n", + "Suspicious Models detected by 3: [1]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "Finished round 7/10\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 7\n", + "Benign Train set: Avg. loss: 0.7909779535963181, Accuracy: 4257/6000 (70.95000147819519%)\n", + "Benign Test set: Avg. loss: 0.9977259238560995, Accuracy: 3812/6000 (63.53332996368408%)\n", + "Backdoor Test set: Avg. loss: 6.6386739412943525, Accuracy: 255/6000 (4.250000044703484%)\n", + "Benign Train set: Avg. loss: 0.9090553227257221, Accuracy: 3999/6000 (66.64999723434448%)\n", + "Benign Test set: Avg. loss: 1.183224380016327, Accuracy: 3433/6000 (57.216668128967285%)\n", + "Backdoor Test set: Avg. loss: 5.912643750508626, Accuracy: 350/6000 (5.833333358168602%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 7\n", + "Benign Train set: Avg. loss: 0.814667098382686, Accuracy: 4235/6000 (70.58333158493042%)\n", + "Benign Test set: Avg. loss: 1.0011651416619618, Accuracy: 3827/6000 (63.7833297252655%)\n", + "Backdoor Test set: Avg. loss: 6.532990614573161, Accuracy: 263/6000 (4.383333399891853%)\n", + "Benign Train set: Avg. loss: 0.8959275081436685, Accuracy: 3980/6000 (66.33332967758179%)\n", + "Benign Test set: Avg. loss: 1.231977681318919, Accuracy: 3323/6000 (55.38333058357239%)\n", + "Backdoor Test set: Avg. loss: 6.533595641454061, Accuracy: 171/6000 (2.850000001490116%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 7\n", + "Benign Train set: Avg. loss: 0.8047349789041154, Accuracy: 4334/6000 (72.2333312034607%)\n", + "Benign Test set: Avg. loss: 0.9804023404916128, Accuracy: 3884/6000 (64.73333239555359%)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Backdoor Test set: Avg. loss: 6.598117272059123, Accuracy: 265/6000 (4.416666552424431%)\n", + "Benign Train set: Avg. loss: 0.8541606298469483, Accuracy: 4111/6000 (68.51666569709778%)\n", + "Benign Test set: Avg. loss: 1.1488163471221924, Accuracy: 3502/6000 (58.36666822433472%)\n", + "Backdoor Test set: Avg. loss: 7.346678733825684, Accuracy: 62/6000 (1.0333333164453506%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 7\n", + "Benign Train set: Avg. loss: 3.7933548435251763, Accuracy: 2059/6000 (34.316664934158325%)\n", + "Benign Test set: Avg. loss: 0.9883201122283936, Accuracy: 3818/6000 (63.63333463668823%)\n", + "Backdoor Test set: Avg. loss: 6.588667392730713, Accuracy: 276/6000 (4.600000008940697%)\n", + "Benign Train set: Avg. loss: 0.5654542832932574, Accuracy: 4768/6000 (79.46666479110718%)\n", + "Benign Test set: Avg. loss: 1.2047673066457112, Accuracy: 3285/6000 (54.750001430511475%)\n", + "Backdoor Test set: Avg. loss: 0.03584048183013996, Accuracy: 5962/6000 (99.36666488647461%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 7\n", + "Distance: cosine, use y 2: [4.550742342756802, 0.2734047995577864]\n", + "Distance: cosine, use x 2: [17.131086555275665, 0.27340479955778685]\n", + "Distance: cosine, use y 1: [10.954451150103313]\n", + "Distance: cosine, use x 2: [10.954451150103317, 0.0]\n", + "Distance: euclid, use y 2: [0.06440376446433316, 0.4009909516764516]\n", + "Distance: euclid, use x 2: [20.4890175814253, 0.06440376446433316]\n", + "Distance: euclid, use y 1: [10.954451150103315]\n", + "Distance: euclid, use x 2: [10.954451150103324, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 7\n", + "Distance: cosine, use y 2: [1.8193887031983578, 1.206869789505315]\n", + "Distance: cosine, use x 2: [15.456868628253602, 1.206869789505315]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.16970625322009436, 0.06365335148335571]\n", + "Distance: euclid, use x 2: [20.582243600502544, 0.06365335148335483]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 7\n", + "Distance: cosine, use y 2: [2.287530253603258, 1.1780715731875366]\n", + "Distance: cosine, use x 2: [15.130761367234994, 1.1780715731875362]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.43618193721079734, 0.0809130511347993]\n", + "Distance: euclid, use x 2: [20.467251355554314, 0.0809130511347993]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 7\n", + "Distance: cosine, use y 2: [4.928247882158452, 0.8293360690090639]\n", + "Distance: cosine, use x 2: [15.459903350249837, 0.8293360690090641]\n", + "Distance: cosine, use y 1: [10.954451150103315]\n", + "Distance: cosine, use x 2: [10.954451150103345, 0.0]\n", + "Distance: euclid, use y 2: [6.016499289987826, 0.40066896841297606]\n", + "Distance: euclid, use x 2: [15.774247119647548, 0.4006689684129756]\n", + "Distance: euclid, use y 1: [10.954451150103194]\n", + "Distance: euclid, use x 2: [10.954451150103461, 0.0]\n", + "Suspicious Models detected by 3: [0]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "Finished round 8/10\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 8\n", + "Benign Train set: Avg. loss: 0.7790621568230872, Accuracy: 4358/6000 (72.63333201408386%)\n", + "Benign Test set: Avg. loss: 1.0014758507410686, Accuracy: 3824/6000 (63.733333349227905%)\n", + "Backdoor Test set: Avg. loss: 6.370962937672933, Accuracy: 208/6000 (3.466666489839554%)\n", + "Benign Train set: Avg. loss: 0.7508002264385528, Accuracy: 4372/6000 (72.86666631698608%)\n", + "Benign Test set: Avg. loss: 1.097122073173523, Accuracy: 3609/6000 (60.14999747276306%)\n", + "Backdoor Test set: Avg. loss: 6.9121631781260175, Accuracy: 271/6000 (4.516666755080223%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 8\n", + "Benign Train set: Avg. loss: 0.7930318139334942, Accuracy: 4332/6000 (72.2000002861023%)\n", + "Benign Test set: Avg. loss: 1.015789379676183, Accuracy: 3775/6000 (62.91666626930237%)\n", + "Backdoor Test set: Avg. loss: 6.28120477994283, Accuracy: 179/6000 (2.983333356678486%)\n", + "Benign Train set: Avg. loss: 0.7537304603673042, Accuracy: 4306/6000 (71.76666855812073%)\n", + "Benign Test set: Avg. loss: 1.0869876742362976, Accuracy: 3616/6000 (60.26666760444641%)\n", + "Backdoor Test set: Avg. loss: 6.068501631418864, Accuracy: 412/6000 (6.866666674613953%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 8\n", + "Benign Train set: Avg. loss: 0.7983413157944984, Accuracy: 4314/6000 (71.8999981880188%)\n", + "Benign Test set: Avg. loss: 0.982840488354365, Accuracy: 3872/6000 (64.53333497047424%)\n", + "Backdoor Test set: Avg. loss: 6.347102244695027, Accuracy: 213/6000 (3.5500001162290573%)\n", + "Benign Train set: Avg. loss: 0.7916458041426984, Accuracy: 4260/6000 (70.99999785423279%)\n", + "Benign Test set: Avg. loss: 1.099749505519867, Accuracy: 3648/6000 (60.79999804496765%)\n", + "Backdoor Test set: Avg. loss: 6.888438145319621, Accuracy: 86/6000 (1.4333332888782024%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 8\n", + "Benign Train set: Avg. loss: 3.675917237997055, Accuracy: 2010/6000 (33.50000083446503%)\n", + "Benign Test set: Avg. loss: 0.9934276640415192, Accuracy: 3835/6000 (63.91666531562805%)\n", + "Backdoor Test set: Avg. loss: 6.315940618515015, Accuracy: 209/6000 (3.4833334386348724%)\n", + "Benign Train set: Avg. loss: 0.5355115555068279, Accuracy: 4838/6000 (80.63333034515381%)\n", + "Benign Test set: Avg. loss: 1.1687040726343791, Accuracy: 3400/6000 (56.66666626930237%)\n", + "Backdoor Test set: Avg. loss: 0.02521951838086049, Accuracy: 5985/6000 (99.75000023841858%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 8\n", + "Distance: cosine, use y 2: [2.0330273940615093, 0.783919831778948]\n", + "Distance: cosine, use x 2: [18.796337259156225, 0.7839198317789475]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.07222110281966465, 0.0361506101254232]\n", + "Distance: euclid, use x 2: [20.62949826235959, 0.03615061012542409]\n", + "Distance: euclid, use y 1: [10.954451150103333]\n", + "Distance: euclid, use x 2: [10.954451150103331, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 8\n", + "Distance: cosine, use y 2: [0.7246317181099711, 1.4790219293692735]\n", + "Distance: cosine, use x 2: [18.932616050326253, 0.7246317181099711]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.1746180787639684, 0.005267232904717645]\n", + "Distance: euclid, use x 2: [20.596317810001032, 0.005267232904717645]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.95445115010333, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 8\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Distance: cosine, use y 2: [0.24431579576776663, 5.2569501477100475]\n", + "Distance: cosine, use x 2: [16.76755929301521, 0.2443157957677662]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [0.07282139662990428, 0.0038479150123968964]\n", + "Distance: euclid, use x 2: [20.63099133657362, 0.0038479150123968964]\n", + "Distance: euclid, use y 1: [10.954451150103338]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 8\n", + "Distance: cosine, use y 2: [0.8265963692700686, 7.309482023131439]\n", + "Distance: cosine, use x 2: [12.158018104786466, 0.8265963692700686]\n", + "Distance: cosine, use y 1: [10.954451150103342]\n", + "Distance: cosine, use x 2: [10.95445115010333, 0.0]\n", + "Distance: euclid, use y 2: [1.7396525584425433, 4.420115231286226]\n", + "Distance: euclid, use x 2: [14.211445576787916, 1.7396525584425433]\n", + "Distance: euclid, use y 1: [10.954451150102678]\n", + "Distance: euclid, use x 2: [10.95445115010399, 0.0]\n", + "Suspicious Models detected by 3: [1, 2]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "Finished round 9/10\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_00 in round 9\n", + "Benign Train set: Avg. loss: 0.7373616132647434, Accuracy: 4426/6000 (73.7666666507721%)\n", + "Benign Test set: Avg. loss: 0.9840383330980936, Accuracy: 3831/6000 (63.84999752044678%)\n", + "Backdoor Test set: Avg. loss: 6.371568520863851, Accuracy: 282/6000 (4.699999839067459%)\n", + "Benign Train set: Avg. loss: 0.7673709408399907, Accuracy: 4325/6000 (72.08333015441895%)\n", + "Benign Test set: Avg. loss: 1.117681125799815, Accuracy: 3500/6000 (58.33333134651184%)\n", + "Backdoor Test set: Avg. loss: 6.13549280166626, Accuracy: 459/6000 (7.6499998569488525%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_01 in round 9\n", + "Benign Train set: Avg. loss: 0.7476775199174881, Accuracy: 4382/6000 (73.03333282470703%)\n", + "Benign Test set: Avg. loss: 1.0023195147514343, Accuracy: 3808/6000 (63.466668128967285%)\n", + "Backdoor Test set: Avg. loss: 6.306597391764323, Accuracy: 269/6000 (4.483333230018616%)\n", + "Benign Train set: Avg. loss: 0.8192861509766984, Accuracy: 4132/6000 (68.86666417121887%)\n", + "Benign Test set: Avg. loss: 1.1944966117540996, Accuracy: 3464/6000 (57.733333110809326%)\n", + "Backdoor Test set: Avg. loss: 6.477320432662964, Accuracy: 492/6000 (8.20000022649765%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator benign_02 in round 9\n", + "Benign Train set: Avg. loss: 0.7560380529215995, Accuracy: 4401/6000 (73.35000038146973%)\n", + "Benign Test set: Avg. loss: 0.9617986083030701, Accuracy: 3934/6000 (65.56666493415833%)\n", + "Backdoor Test set: Avg. loss: 6.3540685176849365, Accuracy: 283/6000 (4.716666787862778%)\n", + "Benign Train set: Avg. loss: 0.8863914824546651, Accuracy: 4003/6000 (66.71666502952576%)\n", + "Benign Test set: Avg. loss: 1.230345328648885, Accuracy: 3418/6000 (56.966668367385864%)\n", + "Backdoor Test set: Avg. loss: 8.166247129440308, Accuracy: 100/6000 (1.666666567325592%)\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling train\n", + "####################\n", + "Performing model training for collaborator malicious_00 in round 9\n", + "Benign Train set: Avg. loss: 3.661377598630621, Accuracy: 2062/6000 (34.3666672706604%)\n", + "Benign Test set: Avg. loss: 0.977350095907847, Accuracy: 3840/6000 (63.999998569488525%)\n", + "Backdoor Test set: Avg. loss: 6.344151099522908, Accuracy: 275/6000 (4.583333432674408%)\n", + "Benign Train set: Avg. loss: 0.5469587276432109, Accuracy: 4805/6000 (80.08333444595337%)\n", + "Benign Test set: Avg. loss: 1.229085087776184, Accuracy: 3243/6000 (54.04999852180481%)\n", + "Backdoor Test set: Avg. loss: 0.010714995209127665, Accuracy: 5995/6000 (99.91666674613953%)\n", + "Scale Model by 4.0\n", + "Should transfer from train to collect_models\n", + "\n", + "Calling collect_models\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_00 in round 9\n", + "Distance: cosine, use y 2: [2.7217097868431646, 0.48909488332725637]\n", + "Distance: cosine, use x 2: [18.132140408145283, 0.4890948833272559]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [1.929935169406828, 1.3124232630177928]\n", + "Distance: euclid, use x 2: [16.533672920014247, 1.3124232630177928]\n", + "Distance: euclid, use y 1: [10.954451150103338]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 0: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_01 in round 9\n", + "Distance: cosine, use y 2: [1.9797955990002887, 0.5455339631621223]\n", + "Distance: cosine, use x 2: [18.993872137326505, 0.5455339631621232]\n", + "Distance: cosine, use y 1: [10.95445115010334]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [5.058928745362662, 0.025827180352623547]\n", + "Distance: euclid, use x 2: [14.406930898232122, 0.025827180352623547]\n", + "Distance: euclid, use y 1: [10.95445115010334]\n", + "Distance: euclid, use x 2: [10.954451150103328, 0.0]\n", + "Suspicious Models detected by 1: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator benign_02 in round 9\n", + "Distance: cosine, use y 2: [0.6147782894408031, 2.1820622362089743]\n", + "Distance: cosine, use x 2: [18.351482792150968, 0.6147782894408023]\n", + "Distance: cosine, use y 1: [10.954451150103338]\n", + "Distance: cosine, use x 2: [10.954451150103328, 0.0]\n", + "Distance: euclid, use y 2: [2.4907861807277247, 0.3759675577613226]\n", + "Distance: euclid, use x 2: [16.371604563916964, 0.3759675577613226]\n", + "Distance: euclid, use y 1: [10.954451150103338]\n", + "Distance: euclid, use x 2: [10.95445115010333, 0.0]\n", + "Suspicious Models detected by 2: [3]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling local_validation\n", + "Performing model validation for collaborator malicious_00 in round 9\n", + "Distance: cosine, use y 2: [6.150505334609076, 1.3753749625020746]\n", + "Distance: cosine, use x 2: [13.370018249411753, 1.3753749625020746]\n", + "Distance: cosine, use y 1: [10.9544511501033]\n", + "Distance: cosine, use x 2: [10.954451150103354, 0.0]\n", + "Distance: euclid, use y 2: [10.310470551411303, 0.37635097834506465]\n", + "Distance: euclid, use x 2: [11.47433869492199, 0.37635097834506465]\n", + "Distance: euclid, use y 1: [10.954451150103175]\n", + "Distance: euclid, use x 2: [10.954451150103484, 0.0]\n", + "Suspicious Models detected by 3: [2]\n", + "Should transfer from local_validation to defend\n", + "\n", + "Calling defend\n", + "Agglomerative Clustering: {0: array([0, 1, 2]), 1: array([3])}\n", + "DBScan Input: [0 1 2]\n", + "DBScan Clustering: [0 1 2]\n", + "Negatives: [0, 1, 2]\n", + "\n", + "Calling end\n", + "####################\n", + "All rounds completed successfully\n", + "####################\n", + "This is the end of the flow\n", + "####################\n" + ] + } + ], + "source": [ + "local_runtime = LocalRuntime(aggregator=aggregator_object, collaborators=collaborators)\n", + "\n", + "print(f\"Local runtime collaborators = {local_runtime.collaborators}\")\n", + "\n", + "# change to the internal flow loop\n", + "model = Net()\n", + "model.load_state_dict(pretrained_weights)\n", + "top_model_accuracy = 0\n", + "optimizers = {\n", + " collaborator.name: default_optimizer(model, optimizer_type=args.optimizer_type)\n", + " for collaborator in collaborators\n", + " }\n", + "flflow = FederatedFlow(\n", + " model,\n", + " optimizers,\n", + " device,\n", + " args.comm_round,\n", + " top_model_accuracy,\n", + " NUMBER_OF_MALICIOUS_CLIENTS / TOTAL_CLIENT_NUMBER,\n", + " 'CrowdGuard'\n", + " )\n", + "\n", + "flflow.runtime = local_runtime\n", + "flflow.run()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/openfl-tutorials/experimental/CrowdGuard/cifar10_crowdguard.py b/openfl-tutorials/experimental/CrowdGuard/cifar10_crowdguard.py new file mode 100644 index 00000000000..5e5901367b2 --- /dev/null +++ b/openfl-tutorials/experimental/CrowdGuard/cifar10_crowdguard.py @@ -0,0 +1,620 @@ +#!/usr/bin/env python +# coding: utf-8 + +# Copyright (C) 2022-2024 TU Darmstadt +# SPDX-License-Identifier: Apache-2.0 + +# ----------------------------------------------------------- +# Primary author: Phillip Rieger +# Co-authored-by: Torsten Krauss +# ------------------------------------------------------------ + +import argparse +import random +import time +import warnings + +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import TensorDataset +import torch.optim as optim +from torchvision import transforms, datasets +from sklearn.cluster import AgglomerativeClustering, DBSCAN + +from CrowdGuardClientValidation import CrowdGuardClientValidation +from openfl.experimental.interface import Aggregator, Collaborator, FLSpec +from openfl.experimental.placement import aggregator, collaborator +from openfl.experimental.runtime import LocalRuntime +from urllib.request import urlretrieve + +warnings.filterwarnings("ignore") + +BATCH_SIZE_TRAIN = 32 +BATCH_SIZE_TEST = 1000 +LEARNING_RATE = 0.00075 +MOMENTUM = 0.9 +LOG_INTERVAL = 10 +TOTAL_CLIENT_NUMBER = 4 +PMR = 0.25 +NUMBER_OF_MALICIOUS_CLIENTS = max(1, int(TOTAL_CLIENT_NUMBER * PMR)) if PMR > 0 else 0 +NUMBER_OF_BENIGN_CLIENTS = TOTAL_CLIENT_NUMBER - NUMBER_OF_MALICIOUS_CLIENTS +PRETRAINED_MODEL_FILE = 'pretrained_cifar.pt' + +# set the random seed for repeatable results +RANDOM_SEED = 10 + +VOTE_FOR_BENIGN = 1 +VOTE_FOR_POISONED = 0 +STD_DEV = torch.from_numpy(np.array([0.2023, 0.1994, 0.2010])) +MEAN = torch.from_numpy(np.array([0.4914, 0.4822, 0.4465])) + + +def download_pretrained_model(): + urlretrieve('https://huggingface.co/prieger/cifar10/resolve/main/pretrained_cifar.pt?' + 'download=true', PRETRAINED_MODEL_FILE) + + +def trigger_single_image(image): + """ + Adds a red square with a height/width of 6 pixels into + the upper left corner of the given image. + @param image tensor, containing the normalized pixel values of the image. + The image will be modified in-place. + @return given image + """ + color = (torch.Tensor((1, 0, 0)) - MEAN) / STD_DEV + image[:, 0:6, 0:6] = color.repeat((6, 6, 1)).permute(2, 1, 0) + return image + + +def poison_data(samples_to_poison, labels_to_poison, pdr=0.5): + """ + poisons a given local dataset, consisting of samples and labels, s.t., + the given ratio of this image consists of samples for the backdoor behavior + :param samples_to_poison tensor containing all samples of the local dataset + :param labels_to_poison tensor containing all labels + :param pdr poisoned data rate + :return poisoned local dataset (samples, labels) + """ + if pdr == 0: + return samples_to_poison, labels_to_poison + + assert 0 < pdr <= 1.0 + samples_to_poison = samples_to_poison.clone() + labels_to_poison = labels_to_poison.clone() + + dataset_size = samples_to_poison.shape[0] + num_samples_to_poison = int(dataset_size * pdr) + if num_samples_to_poison == 0: + # corner case for tiny pdrs + assert pdr > 0 # Already checked above + assert dataset_size > 1 + num_samples_to_poison += 1 + + indices = np.random.choice(dataset_size, size=num_samples_to_poison, replace=False) + for image_index in indices: + image = trigger_single_image(samples_to_poison[image_index]) + samples_to_poison[image_index] = image + labels_to_poison[indices] = 2 + return samples_to_poison, labels_to_poison.long() + + +class SequentialWithInternalStatePrediction(nn.Sequential): + """ + Adapted version of Sequential that implements the function predict_internal_states + """ + + def predict_internal_states(self, x): + """ + applies the submodules on the input. Compared to forward, this function also returns + all intermediate outputs + """ + result = [] + for module in self: + x = module(x) + # We can define our layer as we want. We selected Convolutional and + # Linear Modules as layers here. + # Differs for every model architecture. + # Can be defined by the defender. + if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): + result.append(x) + return result, x + + +class Net(nn.Module): + def __init__(self, num_classes=10): + super(Net, self).__init__() + self.features = SequentialWithInternalStatePrediction( + nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2), + nn.Conv2d(64, 192, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2), + nn.Conv2d(192, 384, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(384, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2), + ) + self.classifier = SequentialWithInternalStatePrediction( + nn.Dropout(), + nn.Linear(256 * 2 * 2, 4096), + nn.ReLU(inplace=True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(inplace=True), + nn.Linear(4096, num_classes), + ) + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), 256 * 2 * 2) + x = self.classifier(x) + return x + + def predict_internal_states(self, x): + result, x = self.features.predict_internal_states(x) + x = x.view(x.size(0), 256 * 2 * 2) + result += self.classifier.predict_internal_states(x)[0] + return result + + +def default_optimizer(model, optimizer_type=None, optimizer_like=None): + """ + Return a new optimizer based on the optimizer_type or the optimizer template + + Args: + model: NN model architected from nn.module class + optimizer_type: "SGD" or "Adam" + optimizer_like: "torch.optim.SGD" or "torch.optim.Adam" optimizer + """ + if optimizer_type == "SGD" or isinstance(optimizer_like, optim.SGD): + return optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM) + elif optimizer_type == "Adam" or isinstance(optimizer_like, optim.Adam): + return optim.Adam(model.parameters()) + + +def test(network, test_loader, device, mode='Benign', move_to_cpu_afterward=True, + test_train='Test'): + network.eval() + network.to(device) + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data = data.to(device) + target = target.to(device) + output = network(data) + criterion = nn.CrossEntropyLoss() + test_loss += criterion(output, target).item() + pred = output.data.max(1, keepdim=True)[1] + correct += pred.eq(target.data.view_as(pred)).sum() + test_loss /= len(test_loader) + accuracy = float(correct / len(test_loader.dataset)) + print( + ( + f"{mode} {test_train} set: Avg. loss: {test_loss}, " + f"Accuracy: {correct}/{len(test_loader.dataset)} ({100.0 * accuracy:5.03f}%)" + ) + ) + if move_to_cpu_afterward: + network.to("cpu") + return accuracy + + +def FedAvg(models): # NOQA: N802 + """ + Return a Federated average model based on Fedavg algorithm: H. B. Mcmahan, + E. Moore, D. Ramage, S. Hampson, and B. A. Y.Arcas, + “Communication-efficient learning of deep networks from decentralized data,” 2017. + + Args: + models: Python list of locally trained models by each collaborator + """ + new_model = models[0] + if len(models) > 1: + state_dicts = [model.state_dict() for model in models] + state_dict = new_model.state_dict() + for key in models[1].state_dict(): + state_dict[key] = np.sum( + [state[key] for state in state_dicts], axis=0 + ) / len(models) + new_model.load_state_dict(state_dict) + return new_model + + +def scale_update_of_model(to_scale, global_model, scaling_factor): + """ + Scales the update of a local model (thus the difference between global and local model) + :param to_scale: local model as state dict + :pram global_model + :param scaling factor + :return scaled local model as state dict + """ + print(f'Scale Model by {scaling_factor}') + result = {} + for name, data in to_scale.items(): + if not (name.endswith('.bias') or name.endswith('.weight')): + result[name] = data + else: + update = data - global_model[name] + scaled = scaling_factor * update + result[name] = scaled + global_model[name] + return result + + +def create_cluster_map_from_labels(expected_number_of_labels, clustering_labels): + """ + Converts a list of labels into a dictionary where each label is the key and + the values are lists/np arrays of the indices from the samples that received + the respective label + :param expected_number_of_labels number of samples whose labels are contained in + clustering_labels + :param clustering_labels list containing the labels of each sample + :return dictionary of clusters + """ + assert len(clustering_labels) == expected_number_of_labels + + clusters = {} + for i, cluster in enumerate(clustering_labels): + if cluster not in clusters: + clusters[cluster] = [] + clusters[cluster].append(i) + return {index: np.array(cluster) for index, cluster in clusters.items()} + + +def determine_biggest_cluster(clustering): + """ + Given a clustering, given as dictionary of the form {cluster_id: [items in cluster]}, the + function returns the id of the biggest cluster + """ + biggest_cluster_id = None + biggest_cluster_size = None + for cluster_id, cluster in clustering.items(): + size_of_current_cluster = np.array(cluster).shape[0] + if biggest_cluster_id is None or size_of_current_cluster > biggest_cluster_size: + biggest_cluster_id = cluster_id + biggest_cluster_size = size_of_current_cluster + return biggest_cluster_id + + +class FederatedFlow(FLSpec): + def __init__(self, model, optimizers, device="cpu", total_rounds=10, top_model_accuracy=0, + pmr=0.25, aggregation_algorithm='FedAVG', **kwargs, ): + if aggregation_algorithm not in ['FedAVG', 'CrowdGuard']: + raise Exception(f'Unsupported Aggregation Algorithm: {aggregation_algorithm}') + super().__init__(**kwargs) + self.aggregation_algorithm = aggregation_algorithm + self.model = model + self.global_model = Net() + self.pmr = pmr + self.start_time = None + self.collaborators = None + self.private = None + self.optimizers = optimizers + self.total_rounds = total_rounds + self.top_model_accuracy = top_model_accuracy + self.device = device + self.round_num = 0 # starting round + print(20 * "#") + print(f"Round {self.round_num}...") + print(20 * "#") + + @aggregator + def start(self): + self.start_time = time.time() + print("Performing initialization for model") + self.collaborators = self.runtime.collaborators + self.private = 10 + self.next( + self.train, + foreach="collaborators", + exclude=["private"], + ) + + # @collaborator # Uncomment if you want ro run on CPU + @collaborator(num_gpus=1) # Assuming GPU(s) is available on the machine + def train(self): + self.collaborator_name = self.input + print(20 * "#") + print(f"Performing model training for collaborator {self.input} in round {self.round_num}") + + self.model.to(self.device) + original_model = {n: d.clone() for n, d in self.model.state_dict().items()} + test(self.model, self.train_loader, self.device, move_to_cpu_afterward=False, + test_train='Train') + test(self.model, self.test_loader, self.device, move_to_cpu_afterward=False) + test(self.model, self.backdoor_test_loader, self.device, mode='Backdoor', + move_to_cpu_afterward=False) + self.optimizer = default_optimizer(self.model, optimizer_like=self.optimizers[self.input]) + + self.model.train() + train_losses = [] + for batch_idx, (data, target) in enumerate(self.train_loader): + data = data.to(self.device) + target = target.to(self.device) + self.optimizer.zero_grad() + output = self.model(data) + criterion = nn.CrossEntropyLoss() + loss = criterion(output, target).to(self.device) + loss.backward() + self.optimizer.step() + if batch_idx % LOG_INTERVAL == 0: + train_losses.append(loss.item()) + + self.loss = np.mean(train_losses) + self.training_completed = True + + test(self.model, self.train_loader, self.device, move_to_cpu_afterward=False, + test_train='Train') + test(self.model, self.test_loader, self.device, move_to_cpu_afterward=False) + test(self.model, self.backdoor_test_loader, self.device, mode='Backdoor', + move_to_cpu_afterward=False) + if 'malicious' in self.input: + weights = self.model.state_dict() + scaled = scale_update_of_model(weights, original_model, 1 / self.pmr) + self.model.load_state_dict(scaled) + self.model.to("cpu") + torch.cuda.empty_cache() + if self.aggregation_algorithm == 'FedAVG': + self.next(self.fed_avg_aggregation, exclude=["training_completed"]) + else: + self.next(self.collect_models, exclude=["training_completed"]) + + @aggregator + def fed_avg_aggregation(self, inputs): + self.all_models = {input.collaborator_name: input.model.cpu() for input in inputs} + self.model = FedAvg([m.cpu() for m in self.all_models.values()]) + self.round_num += 1 + if self.round_num + 1 < self.total_rounds: + self.next(self.train, foreach="collaborators") + else: + self.next(self.end) + + @aggregator + def collect_models(self, inputs): + # Following the CrowdGuard paper, this should be executed within SGX + + self.all_models = {i.collaborator_name: i.model.cpu() for i in inputs} + self.next(self.local_validation, foreach="collaborators") + + @collaborator + def local_validation(self): + # Following the CrowdGuard paper, this should be executed within SGX + print( + f"Performing model validation for collaborator {self.input} in round {self.round_num}" + ) + self.collaborator_name = self.input + all_names = list(self.all_models.keys()) + all_models = [self.all_models[n] for n in all_names] + own_client_index = all_names.index(self.collaborator_name) + detected_suspicious_models = CrowdGuardClientValidation.validate_models(self.global_model, + all_models, + own_client_index, + self.train_loader, + self.device) + detected_suspicious_models = sorted(detected_suspicious_models) + print( + f'Suspicious Models detected by {own_client_index}: {detected_suspicious_models}') + + votes_of_this_client = [] + for c in range(len(all_models)): + if c == own_client_index: + votes_of_this_client.append(VOTE_FOR_BENIGN) + elif c in detected_suspicious_models: + votes_of_this_client.append(VOTE_FOR_POISONED) + else: + votes_of_this_client.append(VOTE_FOR_BENIGN) + self.votes_of_this_client = {} + for name, vote in zip(all_names, votes_of_this_client): + self.votes_of_this_client[name] = vote + + self.next(self.defend) + + @aggregator + def defend(self, inputs): + # Following the CrowdGuard paper, this should be executed within SGX + + all_names = list(self.all_models.keys()) + all_votes_by_name = {i.collaborator_name: i.votes_of_this_client for i in inputs} + + all_models = [self.all_models[name] for name in all_names] + binary_votes = [[all_votes_by_name[own_name][val_name] for val_name in all_names] for + own_name in all_names] + + ac_e = AgglomerativeClustering(n_clusters=2, distance_threshold=None, + compute_full_tree=True, + affinity="euclidean", memory=None, connectivity=None, + linkage='single', + compute_distances=True).fit(binary_votes) + ac_e_labels: list = ac_e.labels_.tolist() + agglomerative_result = create_cluster_map_from_labels(len(all_names), ac_e_labels) + print(f'Agglomerative Clustering: {agglomerative_result}') + agglomerative_negative_cluster = agglomerative_result[ + determine_biggest_cluster(agglomerative_result)] + + db_scan_input_idx_list = agglomerative_negative_cluster + print(f'DBScan Input: {db_scan_input_idx_list}') + db_scan_input_list = [binary_votes[vote_id] for vote_id in db_scan_input_idx_list] + + db = DBSCAN(eps=0.5, min_samples=1).fit(db_scan_input_list) + dbscan_clusters = create_cluster_map_from_labels(len(agglomerative_negative_cluster), + db.labels_.tolist()) + biggest_dbscan_cluster = dbscan_clusters[determine_biggest_cluster(dbscan_clusters)] + print(f'DBScan Clustering: {biggest_dbscan_cluster}') + + single_sample_of_biggest_cluster = biggest_dbscan_cluster[0] + final_voting = db_scan_input_list[single_sample_of_biggest_cluster] + negatives = [i for i, vote in enumerate(final_voting) if vote == VOTE_FOR_BENIGN] + recognized_benign_models = [all_models[n] for n in negatives] + + print(f'Negatives: {negatives}') + + self.model = FedAvg([m.cpu() for m in recognized_benign_models]) + del inputs + self.round_num += 1 + if self.round_num < self.total_rounds: + print(f'Finished round {self.round_num}/{self.total_rounds}') + self.next(self.train, foreach="collaborators") + else: + self.next(self.end) + + @aggregator + def end(self): + print(20 * "#") + print("All rounds completed successfully") + print(20 * "#") + print("This is the end of the flow") + print(20 * "#") + + +def seed_random_generators(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + +if __name__ == '__main__': + + seed_random_generators(RANDOM_SEED) + + argparser = argparse.ArgumentParser(description=__doc__) + argparser.add_argument( + "--test_dataset_ratio", + type=float, + default=0.4, + help="Indicate the what fraction of the sample will be used for testing", + ) + argparser.add_argument( + "--train_dataset_ratio", + type=float, + default=0.4, + help="Indicate the what fraction of the sample will be used for training", + ) + + argparser.add_argument( + "--log_dir", + type=str, + default="test_debug", + help="Indicate where to save the privacy loss profile and log files during the training", + ) + argparser.add_argument( + "--comm_round", + type=int, + default=30, + help="Indicate the communication round of FL", + ) + argparser.add_argument( + "--optimizer_type", + type=str, + default="SGD", + help="Indicate optimizer to use for training", + ) + + args = argparser.parse_args() + + download_pretrained_model() + + # Setup participants + aggregator_object = Aggregator() + aggregator_object.private_attributes = {} + collaborator_names = [f'benign_{i:02d}' for i in range(NUMBER_OF_BENIGN_CLIENTS)] + [ + f'malicious_{i:02d}' for i in range(NUMBER_OF_MALICIOUS_CLIENTS)] + collaborators = [Collaborator(name=name) for name in collaborator_names] + if torch.cuda.is_available(): + device = torch.device( + "cuda:1" + ) # This will enable Ray library to reserve available GPU(s) for the task + else: + device = torch.device("cpu") + + # Prepare local datasets + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(MEAN, STD_DEV), ]) + cifar_train = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform) + cifar_train = list(cifar_train) + cifar_test = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform) + cifar_test = list(cifar_test) + X = torch.stack([x[0] for x in cifar_train] + [x[0] for x in cifar_test]) + Y = torch.LongTensor( + np.stack(np.array([x[1] for x in cifar_train] + [x[1] for x in cifar_test]))) + + # split the dataset + seed_random_generators(RANDOM_SEED) + shuffled_indices = np.arange(X.shape[0]) + np.random.shuffle(shuffled_indices) + + N_total_samples = len(cifar_test) + len(cifar_train) + train_dataset_size = int(N_total_samples * args.train_dataset_ratio) + test_dataset_size = int(N_total_samples * args.test_dataset_ratio) + X = X[shuffled_indices] + Y = Y[shuffled_indices] + + train_dataset_data = X[:train_dataset_size] + train_dataset_targets = Y[:train_dataset_size] + + test_dataset_data = X[train_dataset_size:train_dataset_size + test_dataset_size] + test_dataset_targets = Y[train_dataset_size:train_dataset_size + test_dataset_size] + print(f"Dataset info (total {N_total_samples}): train - {test_dataset_targets.shape[0]}, " + f"test - {test_dataset_targets.shape[0]}, ") + + # partition the dataset for clients + + for idx, collab in enumerate(collaborators): + # construct the training and test and population dataset + benign_training_x = train_dataset_data[idx::len(collaborators)] + benign_training_y = train_dataset_targets[idx::len(collaborators)] + + if 'malicious' in collab.name: + local_train_data, local_train_targets = poison_data(benign_training_x, + benign_training_y) + else: + local_train_data, local_train_targets = benign_training_x, benign_training_y + + local_test_data = test_dataset_data[idx::len(collaborators)] + local_test_targets = test_dataset_targets[idx::len(collaborators)] + + poison_test_data, poison_test_targets = poison_data(local_test_data, local_test_targets, + pdr=1.0) + + collab.private_attributes = { + "train_loader": torch.utils.data.DataLoader( + TensorDataset(local_train_data, local_train_targets), + batch_size=BATCH_SIZE_TRAIN, shuffle=True + ), + "test_loader": torch.utils.data.DataLoader( + TensorDataset(local_test_data, local_test_targets), + batch_size=BATCH_SIZE_TEST, shuffle=False + ), + "backdoor_test_loader": torch.utils.data.DataLoader( + TensorDataset(poison_test_data, poison_test_targets), + batch_size=BATCH_SIZE_TEST, shuffle=False + ), + } + + local_runtime = LocalRuntime(aggregator=aggregator_object, collaborators=collaborators) + + print(f"Local runtime collaborators = {local_runtime.collaborators}") + + # change to the internal flow loop + model = Net() + top_model_accuracy = 0 + optimizers = { + collaborator.name: default_optimizer(model, optimizer_type=args.optimizer_type) + for collaborator in collaborators + } + flflow = FederatedFlow( + model, + optimizers, + device, + args.comm_round, + top_model_accuracy, + NUMBER_OF_MALICIOUS_CLIENTS / TOTAL_CLIENT_NUMBER, + 'CrowdGuard' + ) + flflow.runtime = local_runtime + flflow.run() diff --git a/openfl-tutorials/experimental/CrowdGuard/readme.md b/openfl-tutorials/experimental/CrowdGuard/readme.md new file mode 100644 index 00000000000..2cf614ffcd5 --- /dev/null +++ b/openfl-tutorials/experimental/CrowdGuard/readme.md @@ -0,0 +1,23 @@ +# On the Integration of CrowdGuard into OpenFL +Federated Learning (FL) is a promising approach enabling multiple clients to train Deep Neural Networks (DNNs) collaboratively without sharing their local training data. However, FL is susceptible to backdoor (or targeted poisoning) attacks. These attacks are initiated by malicious clients who seek to compromise the learning process by introducing specific behaviors into the learned model that can be triggered by carefully crafted inputs. Existing FL safeguards have various limitations: They are restricted to specific data distributions or reduce the global model accuracy due to excluding benign models or adding noise, are vulnerable to adaptive defense-aware adversaries, or require the server to access local models, allowing data inference attacks. + +This tutorial implements CrowdGuard [1], which effectively mitigates backdoor attacks in FL and overcomes the deficiencies of existing techniques. It leverages clients' feedback on individual models, analyzes the behavior of neurons in hidden layers, and eliminates poisoned models through an iterative pruning scheme. CrowdGuard employs a server-located stacked clustering scheme to enhance its resilience to rogue client feedback. The experiments that were conducted in the paper show a 100% True-Positive-Rate and True-Negative-Rate across various scenarios, including IID and non-IID data distributions. Additionally, CrowdGuard withstands adaptive adversaries while preserving the original performance of protected models. To ensure confidentiality, CrowdGuard requires a secure and privacy-preserving architecture leveraging Trusted Execution Environments (TEEs) on both client and server sides. Full instructions to set up CrowdGuard's workflows inside TEEs using the OpenFL Workflow API will be made available in a future release of OpenFL. + + + +## Threat Model +Following this, we consider two threat models. +- Backdoor Attacks: Malicious clients aim to inject a backdoor by uploading manipulated model updates. +- Privacy Attacks: The attacker aims to infer information about the clients' data from their local models. Thus, the server tries to gain access to the local models before their aggregation. The clients try to gain access to other clients' local models. + + +## Workflow +We provide a demo code in `cifar10_crowdguard.py` as well as an interactive version as notebook. In the following, we briefly describe the workflow. +In each FL training round, each client trains the global model using its local dataset. Afterward, the server collects the local models and sends them to the clients for the local validation. The clients report the identified suspicious models to the server, which combines these votes using the stacked-clustering scheme to identify the poisoned models. At the end of each round, the identified benign models are aggregated using FedAVG. + +## Methodology +We implemented a simple scaling-based poisoning attack to demonstrate the effectiveness of CrowdGuard. + +For the local validation in CrowdGuard, each client uses its local dataset to obtain the hidden layer outputs for each local model. Then it calculates the Euclidean and Cosine Distance, before applying a PCA. Based on the first principal component, CrowdGuard employs several statistical tests to determine whether poisoned models remain and removes the poisoned models using clustering. This process is repeated until no more poisoned models are detected before sending the detected poisoned models to the server. On the server side, the votes of the individual clients are aggregated using a stacked-clustering scheme to prevent malicious clients from manipulating the aggregation process through manipulated votes. The client-side validation as well as the server-side operations, are executed with SGX to prevent privacy attacks. + +[1] Rieger, P., Krauß, T., Miettinen, M., Dmitrienko, A., & Sadeghi, A. R. CrowdGuard: Federated Backdoor Detection in Federated Learning. NDSS 2024. \ No newline at end of file