diff --git a/.github/workflows/federated_runtime.yml b/.github/workflows/federated_runtime.yml
index 717b96e490..ead1d2791c 100644
--- a/.github/workflows/federated_runtime.yml
+++ b/.github/workflows/federated_runtime.yml
@@ -1,6 +1,5 @@
 #---------------------------------------------------------------------------
 # Workflow to run 301_MNIST_Watermarking notebook
-# Authors - Noopur, Payal Chaurasiya
 #---------------------------------------------------------------------------
 name: Federated Runtime 301 MNIST Watermarking
 
diff --git a/.github/workflows/task_runner_basic_e2e.yml b/.github/workflows/task_runner_basic_e2e.yml
index 4c4aaa12d7..3b336e22af 100644
--- a/.github/workflows/task_runner_basic_e2e.yml
+++ b/.github/workflows/task_runner_basic_e2e.yml
@@ -34,8 +34,8 @@ jobs:
     timeout-minutes: 30
     strategy:
       matrix:
-        # There are open issues for some of the models, so excluding them for now:
-        # model_name: [ "torch_cnn_mnist", "keras_cnn_mnist", "torch_cnn_histology" ]
+        # Models like XGBoost (xgb_higgs) and torch_cnn_histology require runners with more memory and CPU to run.
+        # Thus, these models are excluded from the matrix for now.
         model_name: ["torch_cnn_mnist", "keras_cnn_mnist"]
         python_version: ["3.10", "3.11", "3.12"]
         fail-fast: false # do not immediately fail if one of the combinations fail
@@ -77,7 +77,7 @@ jobs:
     timeout-minutes: 30
     strategy:
       matrix:
-        # Testing non TLS scenario only for torch_cnn_mnist model and python 3.10
+        # Testing this scenario only for the torch_cnn_mnist model and Python 3.10
         # If required, this can be extended to other models and python versions
         model_name: ["torch_cnn_mnist"]
         python_version: ["3.10"]
@@ -120,7 +120,7 @@ jobs:
     timeout-minutes: 30
     strategy:
       matrix:
-        # Testing non TLS scenario only for torch_cnn_mnist model and python 3.10
+        # Testing this scenario only for the keras_cnn_mnist model and Python 3.10
         # If required, this can be extended to other models and python versions
         model_name: ["keras_cnn_mnist"]
         python_version: ["3.10"]
@@ -163,7 +163,7 @@ jobs:
     timeout-minutes: 30
     strategy:
      matrix:
-        # Testing non TLS scenario only for torch_cnn_mnist model and python 3.10
+        # Testing this scenario only for the torch_cnn_mnist model and Python 3.10
         # If required, this can be extended to other models and python versions
         model_name: ["torch_cnn_mnist"]
         python_version: ["3.10"]
diff --git a/.github/workflows/task_runner_dockerized_ws_e2e.yml b/.github/workflows/task_runner_dockerized_ws_e2e.yml
index f48ae32471..ce677c5789 100644
--- a/.github/workflows/task_runner_dockerized_ws_e2e.yml
+++ b/.github/workflows/task_runner_dockerized_ws_e2e.yml
@@ -33,7 +33,7 @@ jobs:
     strategy:
       matrix:
         model_name: ["keras_cnn_mnist"]
-        python_version: ["3.9", "3.10", "3.11"]
+        python_version: ["3.10", "3.11", "3.12"]
         fail-fast: false # do not immediately fail if one of the combinations fail
 
     env:
diff --git a/openfl-tutorials/experimental/workflow/CrowdGuard/CrowdGuardClientValidation.py b/openfl-tutorials/experimental/workflow/CrowdGuard/CrowdGuardClientValidation.py
index 1e8d5e2c59..40dc869773 100644
--- a/openfl-tutorials/experimental/workflow/CrowdGuard/CrowdGuardClientValidation.py
+++ b/openfl-tutorials/experimental/workflow/CrowdGuard/CrowdGuardClientValidation.py
@@ -377,7 +377,7 @@ def __prune_poisoned_models(num_layers, total_number_of_clients, own_client_inde
         ac_e = AgglomerativeClustering(n_clusters=2, distance_threshold=None,
                                        compute_full_tree=True,
-                                       affinity="euclidean", memory=None,
+                                       metric="euclidean", memory=None,
                                        connectivity=None, linkage='single',
                                        compute_distances=True).fit(cluster_input)
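The `affinity` to `metric` rename (applied here and in the two CrowdGuard tutorial files below) tracks scikit-learn's API: the `affinity` parameter of `AgglomerativeClustering` was deprecated in scikit-learn 1.2 and removed in 1.4, so the old keyword fails on current releases. A minimal sketch of the updated call, with a toy input standing in for the tutorial's distance features:

```python
# Minimal sketch of the renamed keyword; the toy input is illustrative only.
import numpy as np
from sklearn.cluster import AgglomerativeClustering

cluster_input = np.array([[0.1], [0.2], [5.0], [5.1]])  # stand-in for the real features
ac_e = AgglomerativeClustering(n_clusters=2, distance_threshold=None,
                               compute_full_tree=True,
                               metric="euclidean",  # was `affinity=` before scikit-learn 1.2
                               linkage='single',
                               compute_distances=True).fit(cluster_input)
print(ac_e.labels_)  # two clusters, e.g. [0 0 1 1]
```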
diff --git a/openfl-tutorials/experimental/workflow/CrowdGuard/PoisoningAttackDemo.ipynb b/openfl-tutorials/experimental/workflow/CrowdGuard/PoisoningAttackDemo.ipynb
index 56283d82b5..b4611a991f 100644
--- a/openfl-tutorials/experimental/workflow/CrowdGuard/PoisoningAttackDemo.ipynb
+++ b/openfl-tutorials/experimental/workflow/CrowdGuard/PoisoningAttackDemo.ipynb
@@ -430,9 +430,8 @@
     "    state_dicts = [model.state_dict() for model in models]\n",
     "    state_dict = new_model.state_dict()\n",
     "    for key in models[1].state_dict():\n",
-    "        state_dict[key] = np.sum(\n",
-    "            [state[key] for state in state_dicts], axis=0\n",
-    "        ) / len(models)\n",
+    "        state_dict[key] = torch.from_numpy(\n",
+    "            np.average([state[key].numpy() for state in state_dicts], axis=0))\n",
     "    new_model.load_state_dict(state_dict)\n",
     "    return new_model\n",
     "\n",
@@ -558,8 +557,7 @@
     "        exclude=[\"private\"],\n",
     "    )\n",
     "\n",
-    "    # @collaborator # Uncomment if you want ro run on CPU\n",
-    "    @collaborator(num_gpus=1) # Assuming GPU(s) is available on the machine\n",
+    "    @collaborator\n",
     "    def train(self):\n",
     "        self.collaborator_name = self.input\n",
     "        print(20 * \"#\")\n",
@@ -669,7 +667,7 @@
     "\n",
     "    ac_e = AgglomerativeClustering(n_clusters=2, distance_threshold=None,\n",
     "                                   compute_full_tree=True,\n",
-    "                                   affinity=\"euclidean\", memory=None, connectivity=None,\n",
+    "                                   metric=\"euclidean\", memory=None, connectivity=None,\n",
     "                                   linkage='single',\n",
     "                                   compute_distances=True).fit(binary_votes)\n",
     "    ac_e_labels: list = ac_e.labels_.tolist()\n",
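The `FedAvg` change above (mirrored in `cifar10_crowdguard.py` below) replaces `np.sum(...) / len(models)` over raw `state_dict` entries with an explicit NumPy round-trip, so that `load_state_dict` receives `torch.Tensor`s rather than NumPy arrays. A minimal sketch of the same element-wise averaging, assuming CPU-resident floating-point tensors:

```python
# Minimal sketch, assuming CPU tensors with floating-point state entries.
import numpy as np
import torch
import torch.nn as nn

def average_state_dicts(models):
    state_dicts = [m.state_dict() for m in models]
    averaged = {}
    for key in state_dicts[0]:
        # Convert each tensor to NumPy, take the element-wise mean, convert back.
        averaged[key] = torch.from_numpy(
            np.average([sd[key].numpy() for sd in state_dicts], axis=0))
    return averaged

models = [nn.Linear(4, 2) for _ in range(3)]
models[0].load_state_dict(average_state_dicts(models))  # weights are now the mean
```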
diff --git a/openfl-tutorials/experimental/workflow/CrowdGuard/cifar10_crowdguard.py b/openfl-tutorials/experimental/workflow/CrowdGuard/cifar10_crowdguard.py
index bc147c1946..559d6f32f7 100644
--- a/openfl-tutorials/experimental/workflow/CrowdGuard/cifar10_crowdguard.py
+++ b/openfl-tutorials/experimental/workflow/CrowdGuard/cifar10_crowdguard.py
@@ -220,9 +220,8 @@ def FedAvg(models):  # NOQA: N802
     state_dicts = [model.state_dict() for model in models]
     state_dict = new_model.state_dict()
     for key in models[1].state_dict():
-        state_dict[key] = np.sum(
-            [state[key] for state in state_dicts], axis=0
-        ) / len(models)
+        state_dict[key] = torch.from_numpy(
+            np.average([state[key].numpy() for state in state_dicts], axis=0))
     new_model.load_state_dict(state_dict)
     return new_model
 
@@ -316,8 +315,7 @@ def start(self):
             exclude=["private"],
         )
 
-    # @collaborator # Uncomment if you want ro run on CPU
-    @collaborator(num_gpus=1) # Assuming GPU(s) is available on the machine
+    @collaborator
     def train(self):
         self.collaborator_name = self.input
         print(20 * "#")
@@ -428,7 +426,7 @@ def defend(self, inputs):
 
         ac_e = AgglomerativeClustering(n_clusters=2, distance_threshold=None,
                                        compute_full_tree=True,
-                                       affinity="euclidean", memory=None, connectivity=None,
+                                       metric="euclidean", memory=None, connectivity=None,
                                        linkage='single',
                                        compute_distances=True).fit(binary_votes)
         ac_e_labels: list = ac_e.labels_.tolist()
diff --git a/openfl-tutorials/experimental/workflow/CrowdGuard/readme.md b/openfl-tutorials/experimental/workflow/CrowdGuard/readme.md
index 2cf614ffcd..d252e93073 100644
--- a/openfl-tutorials/experimental/workflow/CrowdGuard/readme.md
+++ b/openfl-tutorials/experimental/workflow/CrowdGuard/readme.md
@@ -20,4 +20,29 @@ We implemented a simple scaling-based poisoning attack to demonstrate the effect
 For the local validation in CrowdGuard, each client uses its local dataset to obtain the hidden layer outputs for each local model. Then it calculates the Euclidean and Cosine Distance, before applying a PCA. Based on the first principal component, CrowdGuard employs several statistical tests to determine whether poisoned models remain and removes the poisoned models using clustering. This process is repeated until no more poisoned models are detected before sending the detected poisoned models to the server. On the server side, the votes of the individual clients are aggregated using a stacked-clustering scheme to prevent malicious clients from manipulating the aggregation process through manipulated votes. The client-side validation as well as the server-side operations, are executed with SGX to prevent privacy attacks.
 
-[1] Rieger, P., Krauß, T., Miettinen, M., Dmitrienko, A., & Sadeghi, A. R. CrowdGuard: Federated Backdoor Detection in Federated Learning. NDSS 2024.
\ No newline at end of file
+[1] Rieger, P., Krauß, T., Miettinen, M., Dmitrienko, A., & Sadeghi, A. R. CrowdGuard: Federated Backdoor Detection in Federated Learning. NDSS 2024.
+
+## Running the CIFAR-10 demo script
+The demo script requires a dedicated allocation of at least 18GB of RAM to run without issues.
+
+1) Create a Python virtual environment for better isolation
+```shell
+python -m venv venv
+source venv/bin/activate
+```
+2) Install OpenFL from the latest sources
+```shell
+git clone https://github.com/securefederatedai/openfl.git && cd openfl
+pip install -e .
+```
+3) Install the requirements for Workflow API
+```shell
+cd openfl-tutorials/experimental/workflow
+pip install -r workflow_interface_requirements.txt
+```
+4) Start the training script
+Note that the number of training rounds can be adjusted via the `--comm_round` parameter:
+```shell
+cd CrowdGuard
+python cifar10_crowdguard.py --comm_round 5
+```
\ No newline at end of file
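As a companion to the readme's description of the local validation, here is a toy sketch of the distance-then-PCA step only, assuming scikit-learn; it omits the statistical tests, the iteration, and the SGX execution:

```python
# Toy sketch of CrowdGuard's distance + PCA step (illustrative, not the full algorithm).
import numpy as np
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_distances, euclidean_distances

hidden = np.random.rand(5, 64)   # hidden-layer outputs of 5 local models on one input
reference = hidden[:1]           # the validating client's own model as reference
features = np.hstack([
    euclidean_distances(hidden, reference),
    cosine_distances(hidden, reference),
])                               # one (euclidean, cosine) pair per model
first_pc = PCA(n_components=1).fit_transform(features).ravel()
print(first_pc)                  # input to the statistical tests and clustering
```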
diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py
index c6829e75b9..d51ea68291 100644
--- a/openfl/component/aggregator/aggregator.py
+++ b/openfl/component/aggregator/aggregator.py
@@ -352,6 +352,7 @@ def _save_model(self, round_number, file_path):
         ]
         tensor_dict = {}
         tensor_tuple_dict = {}
+        next_round_tensors = {}
         for tk in tensor_keys:
             tk_name, _, _, _, _ = tk
             tensor_value = self.tensor_db.get_tensor_from_cache(tk)
@@ -373,7 +374,7 @@ def _save_model(self, round_number, file_path):
                 self.next_model_round_number, ("model",)
             )
         self.persistent_db.finalize_round(
-            tensor_tuple_dict, next_round_tensors, self.round_number, self.best_model_score
+            tensor_tuple_dict, next_round_tensors, round_number, self.best_model_score
         )
         logger.info(
             "Persist model and clean task result for round %s",
diff --git a/tests/end_to_end/conftest.py b/tests/end_to_end/conftest.py
index 6d1ad45b75..05bdb74d1c 100644
--- a/tests/end_to_end/conftest.py
+++ b/tests/end_to_end/conftest.py
@@ -12,6 +12,7 @@
 from tests.end_to_end.utils.logger import configure_logging
 from tests.end_to_end.utils.logger import logger as log
 from tests.end_to_end.utils.conftest_helper import parse_arguments
+import tests.end_to_end.utils.docker_helper as dh
 
 
 def pytest_addoption(parser):
@@ -192,6 +193,11 @@ def pytest_sessionfinish(session, exitstatus):
         shutil.rmtree(cache_dir, ignore_errors=False)
         log.debug(f"Cleared .pytest_cache directory at {cache_dir}")
 
+    # Cleanup docker containers related to aggregator and collaborators, if any.
+    dh.cleanup_docker_containers(list_of_containers=["aggregator", "collaborator*"])
+    # Cleanup docker network created for openfl, if any.
+    dh.remove_docker_network(["openfl"])
+
 
 def pytest_configure(config):
     """
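Two small fixes in `_save_model`: `next_round_tensors` is now initialized up front (it was previously only assigned inside a conditional branch, risking an unbound-variable error), and `finalize_round` receives the function's `round_number` argument instead of `self.round_number`, which can already point at a newer round when a past round is being persisted. A contrived sketch of that pitfall, with hypothetical names:

```python
# Contrived sketch (hypothetical class) of why the argument, not the attribute, matters.
class Aggregator:
    def __init__(self):
        self.round_number = 3  # the round currently in progress

    def _save_model(self, round_number):
        # Persist under the round being saved; self.round_number may already
        # have advanced past it (e.g. when re-saving a best model).
        print(f"persisting round {round_number} (current round: {self.round_number})")

Aggregator()._save_model(round_number=2)  # persisting round 2 (current round: 3)
```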
diff --git a/tests/end_to_end/models/collaborator.py b/tests/end_to_end/models/collaborator.py
index 51cbfeba0a..dd9389128b 100644
--- a/tests/end_to_end/models/collaborator.py
+++ b/tests/end_to_end/models/collaborator.py
@@ -5,9 +5,9 @@
 import logging
 
 import tests.end_to_end.utils.docker_helper as dh
+import tests.end_to_end.utils.exceptions as ex
 import tests.end_to_end.utils.federation_helper as fh
 
-
 log = logging.getLogger(__name__)
 
 
@@ -205,3 +205,23 @@ def import_workspace(self):
         except Exception as e:
             log.error(f"{error_msg}: {e}")
             raise e
+
+    def modify_data_file(self, data_file, index):
+        """
+        Modify the data.yaml file for the model.
+        Args:
+            data_file (str): Path to the data file, including the file name
+            index (int): Index of the collaborator's data directory
+        Returns:
+            bool: True if successful, else False
+        """
+        try:
+            log.info("Modifying the data.yaml file..")
+
+            with open(data_file, "w") as file:
+                file.write(f"{self.collaborator_name},data/{index}")
+
+        except Exception as e:
+            log.error(f"Failed to modify the data file: {e}")
+            raise ex.DataSetupException(f"Failed to modify the data file: {e}")
+
+        return True
diff --git a/tests/end_to_end/utils/constants.py b/tests/end_to_end/utils/constants.py
index a7cbb29312..a9f66e47ef 100644
--- a/tests/end_to_end/utils/constants.py
+++ b/tests/end_to_end/utils/constants.py
@@ -13,6 +13,7 @@ class ModelName(Enum):
     TORCH_CNN_MNIST = "torch_cnn_mnist"
     KERAS_CNN_MNIST = "keras_cnn_mnist"
     TORCH_CNN_HISTOLOGY = "torch_cnn_histology"
+    XGB_HIGGS = "xgb_higgs"
 
 NUM_COLLABORATORS = 2
 NUM_ROUNDS = 5
@@ -31,6 +32,10 @@
 AGG_PLAN_PATH = "{}/aggregator/workspace/plan"  # example - /tmp/my_federation/aggregator/workspace/plan
 COL_PLAN_PATH = "{}/{}/workspace/plan"  # example - /tmp/my_federation/collaborator1/workspace/plan
 
+COL_DATA_FILE = "{}/{}/workspace/plan/data.yaml"  # example - /tmp/my_federation/collaborator1/workspace/plan/data.yaml
+
+DATA_SETUP_FILE = "setup_data.py"  # currently only xgb_higgs uses this file to set up its data
+
 AGG_COL_RESULT_FILE = "{0}/{1}/workspace/{1}.log"  # example - /tmp/my_federation/aggregator/workspace/aggregator.log
 
 AGG_WORKSPACE_ZIP_NAME = "workspace.zip"
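`modify_data_file` collapses the collaborator's `plan/data.yaml` down to a single `name,path` line. A hypothetical call and its result, using the `COL_DATA_FILE` example path from `constants.py`:

```python
# Hypothetical usage; the path and name are illustrative.
collaborator.modify_data_file(
    data_file="/tmp/my_federation/collaborator1/workspace/plan/data.yaml",
    index=1,
)
# data.yaml now contains exactly one line:
#   collaborator1,data/1
```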
diff --git a/tests/end_to_end/utils/docker_helper.py b/tests/end_to_end/utils/docker_helper.py
index bfb8c214cb..d8178652c0 100644
--- a/tests/end_to_end/utils/docker_helper.py
+++ b/tests/end_to_end/utils/docker_helper.py
@@ -12,35 +12,40 @@
 log = logging.getLogger(__name__)
 
 
-def remove_docker_network():
+def remove_docker_network(list_of_networks=[constants.DOCKER_NETWORK_NAME]):
     """
     Remove docker network.
+    Args:
+        list_of_networks (list): List of network names to remove.
     """
     client = get_docker_client()
-    networks = client.networks.list(names=[constants.DOCKER_NETWORK_NAME])
+    networks = client.networks.list(names=list_of_networks)
     if not networks:
-        log.debug(f"Network {constants.DOCKER_NETWORK_NAME} does not exist")
+        log.debug(f"Network(s) {list_of_networks} do not exist")
         return
 
     for network in networks:
         log.debug(f"Removing network: {network.name}")
         network.remove()
-    log.debug("Docker network removed successfully")
+    log.debug(f"Docker network(s) {list_of_networks} removed successfully")
 
 
-def create_docker_network():
+def create_docker_network(list_of_networks=[constants.DOCKER_NETWORK_NAME]):
     """
     Create docker network.
+    Args:
+        list_of_networks (list): List of network names to create.
     """
     client = get_docker_client()
-    networks = client.networks.list(names=[constants.DOCKER_NETWORK_NAME])
+    networks = client.networks.list(names=list_of_networks)
     if networks:
-        log.info(f"Network {constants.DOCKER_NETWORK_NAME} already exists")
+        log.info(f"Network(s) {list_of_networks} already exist")
         return
 
-    log.debug(f"Creating network: {constants.DOCKER_NETWORK_NAME}")
-    network = client.networks.create(constants.DOCKER_NETWORK_NAME)
-    log.info(f"Network {network.name} created successfully")
+    for network_name in list_of_networks:
+        log.debug(f"Creating network: {network_name}")
+        _ = client.networks.create(network_name)
+    log.info(f"Docker network(s) {list_of_networks} created successfully")
 
 
 def check_docker_image():
@@ -143,24 +148,24 @@ def get_docker_client():
     return client
 
 
-def cleanup_docker_containers():
+def cleanup_docker_containers(list_of_containers=["aggregator", "collaborator*"]):
     """
     Cleanup the docker containers meant for openfl.
+    Args:
+        list_of_containers: List of container name patterns to clean up.
     """
     log.debug("Cleaning up docker containers")
     client = get_docker_client()
 
-    # List all containers related to openfl
-    agg_containers = client.containers.list(all=True, filters={"name": "aggregator"})
-    col_containers = client.containers.list(all=True, filters={"name": "collaborator*"})
-    containers = agg_containers + col_containers
-    container_names = []
-    # Stop and remove all containers
-    for container in containers:
-        container.stop()
-        container.remove()
-        container_names.append(container.name)
-
-    if containers:
-        log.info(f"Docker containers {container_names} cleaned up successfully")
+    for container_name in list_of_containers:
+        containers = client.containers.list(all=True, filters={"name": container_name})
+        container_names = []
+        # Stop and remove all matching containers
+        for container in containers:
+            container.stop()
+            container.remove()
+            container_names.append(container.name)
+
+        if containers:
+            log.info(f"Docker containers {container_names} cleaned up successfully")
diff --git a/tests/end_to_end/utils/exceptions.py b/tests/end_to_end/utils/exceptions.py
index 31fa596ac0..2a12842080 100644
--- a/tests/end_to_end/utils/exceptions.py
+++ b/tests/end_to_end/utils/exceptions.py
@@ -86,3 +86,8 @@ class EnvoyStartException(Exception):
 class DirectorStartException(Exception):
     """Exception for director start"""
     pass
+
+
+class DataSetupException(Exception):
+    """Exception raised when data setup fails for a given model"""
+    pass
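The parameterized helpers keep their previous behavior as defaults, so existing call sites work unchanged, while tests can pass explicit lists as `conftest.py` now does. A small sketch of both call styles, using names taken from this diff:

```python
import tests.end_to_end.utils.docker_helper as dh

dh.cleanup_docker_containers()  # defaults to ["aggregator", "collaborator*"]
dh.remove_docker_network()      # defaults to [constants.DOCKER_NETWORK_NAME]

dh.cleanup_docker_containers(list_of_containers=["aggregator"])  # explicit subset
dh.remove_docker_network(["openfl"])                             # explicit list
```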
diff --git a/tests/end_to_end/utils/federation_helper.py b/tests/end_to_end/utils/federation_helper.py
index 50910c4f2e..9bc3a4dcd1 100644
--- a/tests/end_to_end/utils/federation_helper.py
+++ b/tests/end_to_end/utils/federation_helper.py
@@ -7,8 +7,10 @@
 import os
 import json
 import re
+import subprocess  # nosec B404
 import papermill as pm
 from pathlib import Path
+import shutil
 
 import tests.end_to_end.utils.constants as constants
 import tests.end_to_end.utils.docker_helper as dh
@@ -110,18 +112,24 @@ def setup_pki_for_collaborators(collaborators, model_owner, local_bind_path):
     return True
 
 
-def create_tarball_for_collaborators(collaborators, local_bind_path, use_tls):
+def create_tarball_for_collaborators(collaborators, local_bind_path, use_tls, add_data=False):
     """
     Create tarball for all the collaborators
     Args:
         collaborators (list): List of collaborator objects
         local_bind_path (str): Local bind path
         use_tls (bool): Use TLS or not (default is True)
+        add_data (bool): Add data to the tarball (default is False)
     """
     executor = concurrent.futures.ThreadPoolExecutor()
     try:
-        def _create_tarball(collaborator_name, local_bind_path):
+        def _create_tarball(collaborator_name, data_file_path, local_bind_path, add_data):
+            """
+            Internal function to create the tarball for a collaborator.
+            If TLS is enabled, include the client certificates and signed certificates in the tarball.
+            If data needs to be added, include the data folder in the tarball as well.
+            """
             local_col_ws_path = constants.COL_WORKSPACE_PATH.format(
                 local_bind_path, collaborator_name
             )
@@ -134,7 +142,11 @@ def _create_tarball(collaborator_name, local_bind_path):
                 ]
                 client_certs = " ".join(client_cert_entries) if client_cert_entries else ""
                 tarfiles += f" agg_to_col_{collaborator_name}_signed_cert.zip {client_certs}"
 
+            # IMPORTANT: the XGBoost model (xgb_higgs) expects its data under data/1, data/2, etc.,
+            # so the data folder is added to the tarball in the same layout.
+            if add_data:
+                tarfiles += f" data/{data_file_path}"
+
+            log.info(f"Tarfile for {collaborator_name} includes: {tarfiles}")
             return_code, output, error = ssh.run_command(
                 f"tar -cf {tarfiles}", work_dir=local_col_ws_path
             )
@@ -146,9 +158,9 @@ def _create_tarball(collaborator_name, local_bind_path):
 
         results = [
             executor.submit(
-                _create_tarball, collaborator.name, local_bind_path=local_bind_path
+                _create_tarball, collaborator.name, data_file_path=index, local_bind_path=local_bind_path, add_data=add_data
             )
-            for collaborator in collaborators
+            for index, collaborator in enumerate(collaborators, start=1)
         ]
         if not all([f.result() for f in results]):
             raise Exception("Failed to create tarball for one or more collaborators")
@@ -629,18 +641,22 @@ def verify_cmd_output(
         raise Exception(f"{error_msg}: {error}")
 
 
-def setup_collaborator(count, workspace_path, local_bind_path):
+def setup_collaborator(index, workspace_path, local_bind_path):
     """
     Setup the collaborator
     Includes - creation of collaborator objects, starting docker container,
     importing workspace, creating collaborator
+    Args:
+        index (int): Index of the collaborator. Starts with 1.
+        workspace_path (str): Workspace path
+        local_bind_path (str): Local bind path
     """
     local_agg_ws_path = constants.AGG_WORKSPACE_PATH.format(local_bind_path)
     try:
         collaborator = col_model.Collaborator(
-            collaborator_name=f"collaborator{count+1}",
-            data_directory_path=count + 1,
-            workspace_path=f"{workspace_path}/collaborator{count+1}/workspace",
+            collaborator_name=f"collaborator{index}",
+            data_directory_path=index,
+            workspace_path=f"{workspace_path}/collaborator{index}/workspace",
         )
 
         create_persistent_store(collaborator.name, local_bind_path)
@@ -670,6 +686,80 @@
     return collaborator
 
 
+def setup_collaborator_data(collaborators, model_name, local_bind_path):
+    """
+    Function to set up the data for collaborators.
+    IMPORTANT: This function is specific to the model and should be updated as per the model requirements.
+    Args:
+        collaborators (list): List of collaborator objects
+        model_name (str): Model name
+        local_bind_path (str): Local bind path
+    """
+    # Check whether the data already exists; if yes, skip the download.
+    # This is mainly helpful in case of re-runs.
+    if all(
+        os.path.exists(os.path.join(collaborator.workspace_path, "data", str(index)))
+        for index, collaborator in enumerate(collaborators, start=1)
+    ):
+        log.info("Data already exists for all the collaborators. Skipping the download part..")
+        return
+    else:
+        log.info("Data is missing for one or more collaborators. Proceeding with the download..")
+        # The step below also modifies the data.yaml file for all the collaborators.
+        download_data(collaborators, model_name, local_bind_path)
+
+    log.info("Data setup is complete for all the collaborators")
+
+
+def download_data(collaborators, model_name, local_bind_path):
+    """
+    Download the data for the model and copy it to the respective collaborator workspaces.
+    Also modify the data.yaml file for all the collaborators.
+    Args:
+        collaborators (list): List of collaborator objects
+        model_name (str): Model name
+        local_bind_path (str): Local bind path
+    Returns:
+        bool: True if successful, else False
+    """
+    log.info(f"Copying {constants.DATA_SETUP_FILE} from one of the collaborator workspaces to the local bind path..")
+    try:
+        shutil.copyfile(
+            src=os.path.join(collaborators[0].workspace_path, "src", constants.DATA_SETUP_FILE),
+            dst=os.path.join(local_bind_path, constants.DATA_SETUP_FILE)
+        )
+    except Exception as e:
+        raise ex.DataSetupException(f"Failed to copy data setup file: {e}")
+
+    log.info("Downloading the data for the model. This can take a while, depending on the data size..")
+    try:
+        command = ["python", constants.DATA_SETUP_FILE, str(len(collaborators))]
+        subprocess.run(command, cwd=local_bind_path, check=True)
+    except Exception:
+        raise ex.DataSetupException(f"Failed to download data for {model_name}")
+
+    try:
+        # Copy the data to the respective workspaces based on the index
+        for index, collaborator in enumerate(collaborators, start=1):
+            src_folder = os.path.join(local_bind_path, "data", str(index))
+            dst_folder = os.path.join(collaborator.workspace_path, "data", str(index))
+            if os.path.exists(src_folder):
+                shutil.copytree(src_folder, dst_folder, dirs_exist_ok=True)
+                log.info(f"Copied data from {src_folder} to {dst_folder}")
+            else:
+                raise ex.DataSetupException(f"Source folder {src_folder} does not exist for {collaborator.name}")
+
+            # Modify the data.yaml file for the collaborator
+            collaborator.modify_data_file(
+                constants.COL_DATA_FILE.format(local_bind_path, collaborator.name),
+                index,
+            )
+    except Exception as e:
+        raise ex.DataSetupException(f"Failed to modify the data file: {e}")
+
+    # The step below is specific to the XGBoost model, which uses the higgs_data folder to create the data folders.
+    shutil.rmtree(os.path.join(local_bind_path, "higgs_data"), ignore_errors=True)
+
+    return True
+
+
 def extract_memory_usage(log_file):
     """
     Extracts memory usage data from a log file.
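One recurring convention in the data setup: collaborator indices are 1-based (`enumerate(..., start=1)`), so `collaborator1` always pairs with `data/1` in the local bind path, the workspace copy, the tarball, and `data.yaml`. A trivial sketch of the alignment:

```python
# 1-based indexing keeps collaborator names and data shards aligned.
for index, name in enumerate(["collaborator1", "collaborator2"], start=1):
    print(f"{name} -> data/{index}")  # collaborator1 -> data/1, collaborator2 -> data/2
```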
diff --git a/tests/end_to_end/utils/tr_common_fixtures.py b/tests/end_to_end/utils/tr_common_fixtures.py
index 593e57e473..88e02e631c 100644
--- a/tests/end_to_end/utils/tr_common_fixtures.py
+++ b/tests/end_to_end/utils/tr_common_fixtures.py
@@ -92,14 +92,18 @@ def fx_federation_tr(request):
     futures = [
         executor.submit(
             fh.setup_collaborator,
-            count=i,
+            index,
            workspace_path=workspace_path,
             local_bind_path=local_bind_path,
         )
-        for i in range(request.config.num_collaborators)
+        for index in range(1, request.config.num_collaborators + 1)
     ]
     collaborators = [f.result() for f in futures]
 
+    # Data setup requires the total number of collaborators, so it is called only after all of them are created.
+    if model_name.lower() == "xgb_higgs":
+        fh.setup_collaborator_data(collaborators, model_name, local_bind_path)
+
     if request.config.use_tls:
         fh.setup_pki_for_collaborators(collaborators, model_owner, local_bind_path)
         fh.import_pki_for_collaborators(collaborators, local_bind_path)
@@ -181,20 +185,25 @@ def fx_federation_tr_dws(request):
     futures = [
         executor.submit(
             fh.setup_collaborator,
-            count=i,
+            index,
             workspace_path=workspace_path,
             local_bind_path=local_bind_path,
         )
-        for i in range(request.config.num_collaborators)
+        for index in range(1, request.config.num_collaborators + 1)
     ]
     collaborators = [f.result() for f in futures]
 
     if request.config.use_tls:
         fh.setup_pki_for_collaborators(collaborators, model_owner, local_bind_path)
 
+    # Data setup requires the total number of collaborators, so it is called only after all of them are created.
+    if model_name.lower() == "xgb_higgs":
+        fh.setup_collaborator_data(collaborators, model_name, local_bind_path)
+
     # Note: In case of a multiple-machine setup, scp the created tar for collaborators to the other machine(s)
     fh.create_tarball_for_collaborators(
-        collaborators, local_bind_path, use_tls=request.config.use_tls
+        collaborators, local_bind_path, use_tls=request.config.use_tls,
+        add_data=model_name.lower() == "xgb_higgs"
     )
 
     # Generate the sign request and certify the aggregator in case of TLS