E2E automation support for XGBoost model (#1288)
* Task runner automation for xgb_higgs

Signed-off-by: noopur <[email protected]>

* Initial draft

Signed-off-by: noopur <[email protected]>

* Post bare metal and dockerized approach testing

Signed-off-by: noopur <[email protected]>

* Minor corrections and reverts

Signed-off-by: noopur <[email protected]>

* Minor corrections and reverts

Signed-off-by: noopur <[email protected]>

* Minor changes

Signed-off-by: noopur <[email protected]>

* Modularized data download function

Signed-off-by: noopur <[email protected]>

* Remove containers and networks in pytest_sessionfinish

Signed-off-by: noopur <[email protected]>

* Removed 3.9, added 3.12 in dockerized workflow

Signed-off-by: noopur <[email protected]>

* Code format check

Signed-off-by: noopur <[email protected]>

* Bandit issue resolved

Signed-off-by: noopur <[email protected]>

* More bandit issues, nosec added

Signed-off-by: noopur <[email protected]>

* Bandit issue - subprocess corrected

Signed-off-by: noopur <[email protected]>

* Review comments incorporated

Signed-off-by: noopur <[email protected]>

* Code format check

Signed-off-by: noopur <[email protected]>

* Minor comment correction in task runner workflow

Signed-off-by: noopur <[email protected]>

---------

Signed-off-by: noopur <[email protected]>
noopurintel authored Jan 22, 2025
1 parent 8104144 commit 4a1a135
Showing 10 changed files with 184 additions and 45 deletions.
1 change: 0 additions & 1 deletion .github/workflows/federated_runtime.yml
@@ -1,6 +1,5 @@
#---------------------------------------------------------------------------
# Workflow to run 301_MNIST_Watermarking notebook
-# Authors - Noopur, Payal Chaurasiya
#---------------------------------------------------------------------------
name: Federated Runtime 301 MNIST Watermarking

10 changes: 5 additions & 5 deletions .github/workflows/task_runner_basic_e2e.yml
@@ -34,8 +34,8 @@ jobs:
timeout-minutes: 30
strategy:
matrix:
-# There are open issues for some of the models, so excluding them for now:
-# model_name: [ "torch_cnn_mnist", "keras_cnn_mnist", "torch_cnn_histology" ]
+# Models like XGBoost (xgb_higgs) and torch_cnn_histology require runners with higher memory and CPU to run.
+# Thus these models are excluded from the matrix for now.
model_name: ["torch_cnn_mnist", "keras_cnn_mnist"]
python_version: ["3.10", "3.11", "3.12"]
fail-fast: false # do not immediately fail if one of the combinations fail
@@ -77,7 +77,7 @@ jobs:
timeout-minutes: 30
strategy:
matrix:
-# Testing non TLS scenario only for torch_cnn_mnist model and python 3.10
+# Testing this scenario only for torch_cnn_mnist model and python 3.10
# If required, this can be extended to other models and python versions
model_name: ["torch_cnn_mnist"]
python_version: ["3.10"]
@@ -120,7 +120,7 @@ jobs:
timeout-minutes: 30
strategy:
matrix:
-# Testing non TLS scenario only for torch_cnn_mnist model and python 3.10
+# Testing this scenario for keras_cnn_mnist model and python 3.10
# If required, this can be extended to other models and python versions
model_name: ["keras_cnn_mnist"]
python_version: ["3.10"]
@@ -163,7 +163,7 @@ jobs:
timeout-minutes: 30
strategy:
matrix:
-# Testing non TLS scenario only for torch_cnn_mnist model and python 3.10
+# Testing this scenario only for torch_cnn_mnist model and python 3.10
# If required, this can be extended to other models and python versions
model_name: ["torch_cnn_mnist"]
python_version: ["3.10"]
2 changes: 1 addition & 1 deletion .github/workflows/task_runner_dockerized_ws_e2e.yml
@@ -33,7 +33,7 @@ jobs:
strategy:
matrix:
model_name: ["keras_cnn_mnist"]
-python_version: ["3.9", "3.10", "3.11"]
+python_version: ["3.10", "3.11", "3.12"]
fail-fast: false # do not immediately fail if one of the combinations fail

env:
6 changes: 6 additions & 0 deletions tests/end_to_end/conftest.py
@@ -12,6 +12,7 @@
from tests.end_to_end.utils.logger import configure_logging
from tests.end_to_end.utils.logger import logger as log
from tests.end_to_end.utils.conftest_helper import parse_arguments
import tests.end_to_end.utils.docker_helper as dh


def pytest_addoption(parser):
@@ -192,6 +193,11 @@ def pytest_sessionfinish(session, exitstatus):
shutil.rmtree(cache_dir, ignore_errors=False)
log.debug(f"Cleared .pytest_cache directory at {cache_dir}")

# Cleanup docker containers related to aggregator and collaborators, if any.
dh.cleanup_docker_containers(list_of_containers=["aggregator", "collaborator*"])
# Cleanup docker network created for openfl, if any.
dh.remove_docker_network(["openfl"])


def pytest_configure(config):
"""
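For readers reproducing the teardown by hand, the two helper calls above amount to roughly the following with the Docker SDK for Python; this is an illustrative sketch (not part of the diff) and assumes a reachable Docker daemon:

import docker

# Sketch only: rough equivalent of cleanup_docker_containers() + remove_docker_network().
client = docker.from_env()
for name_filter in ("aggregator", "collaborator*"):
    for container in client.containers.list(all=True, filters={"name": name_filter}):
        container.stop()
        container.remove()
for network in client.networks.list(names=["openfl"]):
    network.remove()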
22 changes: 21 additions & 1 deletion tests/end_to_end/models/collaborator.py
@@ -5,9 +5,9 @@
import logging

import tests.end_to_end.utils.docker_helper as dh
import tests.end_to_end.utils.exceptions as ex
import tests.end_to_end.utils.federation_helper as fh


log = logging.getLogger(__name__)


@@ -205,3 +205,23 @@ def import_workspace(self):
except Exception as e:
log.error(f"{error_msg}: {e}")
raise e

def modify_data_file(self, data_file, index):
"""
Modify the data.yaml file for the model
Args:
data_file (str): Path to the data file, including the file name
index (int): Index of the collaborator's data shard, written to the file as data/<index>
Returns:
bool: True if successful, else False
"""
try:
log.info("Data setup completed successfully. Modifying the data.yaml file..")

with open(data_file, "w") as file:
file.write(f"{self.collaborator_name},data/{index}")

except Exception as e:
log.error(f"Failed to modify the data file: {e}")
raise ex.DataSetupException(f"Failed to modify the data file: {e}")

return True
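
A short usage sketch of the new method (illustrative only; the path follows the examples in constants.py, and `collaborator` is assumed to be a Collaborator instance named "collaborator1"):

# Point collaborator1's plan/data.yaml entry at its data shard (data/1).
data_file = "/tmp/my_federation/collaborator1/workspace/plan/data.yaml"
collaborator.modify_data_file(data_file, index=1)
# The file afterwards contains a single line: collaborator1,data/1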
5 changes: 5 additions & 0 deletions tests/end_to_end/utils/constants.py
@@ -13,6 +13,7 @@ class ModelName(Enum):
TORCH_CNN_MNIST = "torch_cnn_mnist"
KERAS_CNN_MNIST = "keras_cnn_mnist"
TORCH_CNN_HISTOLOGY = "torch_cnn_histology"
XGB_HIGGS = "xgb_higgs"

NUM_COLLABORATORS = 2
NUM_ROUNDS = 5
@@ -31,6 +32,10 @@ class ModelName(Enum):
AGG_PLAN_PATH = "{}/aggregator/workspace/plan" # example - /tmp/my_federation/aggregator/workspace/plan
COL_PLAN_PATH = "{}/{}/workspace/plan" # example - /tmp/my_federation/collaborator1/workspace/plan

COL_DATA_FILE = "{}/{}/workspace/plan/data.yaml" # example - /tmp/my_federation/collaborator1/workspace/plan/data.yaml

DATA_SETUP_FILE = "setup_data.py" # currently xgb_higgs uses this file to set up data

AGG_COL_RESULT_FILE = "{0}/{1}/workspace/{1}.log" # example - /tmp/my_federation/aggregator/workspace/aggregator.log

AGG_WORKSPACE_ZIP_NAME = "workspace.zip"
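As a sketch of how the new constants are meant to be expanded (the bind path below is the hypothetical example used in the comments above):

import tests.end_to_end.utils.constants as constants

local_bind_path = "/tmp/my_federation"  # illustrative workspace root
data_file = constants.COL_DATA_FILE.format(local_bind_path, "collaborator1")
# -> /tmp/my_federation/collaborator1/workspace/plan/data.yaml
setup_script = constants.DATA_SETUP_FILE  # "setup_data.py", shipped in the workspace src/ folder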
53 changes: 29 additions & 24 deletions tests/end_to_end/utils/docker_helper.py
@@ -12,35 +12,40 @@
log = logging.getLogger(__name__)


def remove_docker_network():
def remove_docker_network(list_of_networks=[constants.DOCKER_NETWORK_NAME]):
"""
Remove docker network.
Args:
list_of_networks (list): List of network names to remove.
"""
client = get_docker_client()
networks = client.networks.list(names=[constants.DOCKER_NETWORK_NAME])
networks = client.networks.list(names=list_of_networks)
if not networks:
log.debug(f"Network {constants.DOCKER_NETWORK_NAME} does not exist")
log.debug(f"Network(s) {list_of_networks} does not exist")
return

for network in networks:
log.debug(f"Removing network: {network.name}")
network.remove()
log.debug("Docker network removed successfully")
log.debug(f"Docker network(s) {list_of_networks} removed successfully")


def create_docker_network():
def create_docker_network(list_of_networks=[constants.DOCKER_NETWORK_NAME]):
"""
Create docker network.
Args:
list_of_networks (list): List of network names to create.
"""
client = get_docker_client()
networks = client.networks.list(names=[constants.DOCKER_NETWORK_NAME])
networks = client.networks.list(names=list_of_networks)
if networks:
log.info(f"Network {constants.DOCKER_NETWORK_NAME} already exists")
log.info(f"Network(s) {list_of_networks} already exists")
return

log.debug(f"Creating network: {constants.DOCKER_NETWORK_NAME}")
network = client.networks.create(constants.DOCKER_NETWORK_NAME)
log.info(f"Network {network.name} created successfully")
for network_name in list_of_networks:
log.debug(f"Creating network: {network_name}")
_ = client.networks.create(network_name)
log.info(f"Docker network(s) {list_of_networks} created successfully")


def check_docker_image():
@@ -143,24 +148,24 @@ def get_docker_client():
return client


def cleanup_docker_containers():
def cleanup_docker_containers(list_of_containers=["aggregator", "collaborator*"]):
"""
Cleanup the docker containers meant for openfl.
Args:
list_of_containers: List of container names to cleanup.
"""
log.debug("Cleaning up docker containers")

client = get_docker_client()

# List all containers related to openfl
agg_containers = client.containers.list(all=True, filters={"name": "aggregator"})
col_containers = client.containers.list(all=True, filters={"name": "collaborator*"})
containers = agg_containers + col_containers
container_names = []
# Stop and remove all containers
for container in containers:
container.stop()
container.remove()
container_names.append(container.name)

if containers:
log.info(f"Docker containers {container_names} cleaned up successfully")
for container_name in list_of_containers:
containers = client.containers.list(all=True, filters={"name": container_name})
container_names = []
# Stop and remove all containers
for container in containers:
container.stop()
container.remove()
container_names.append(container.name)

if containers:
log.info(f"Docker containers {container_names} cleaned up successfully")
5 changes: 5 additions & 0 deletions tests/end_to_end/utils/exceptions.py
@@ -86,3 +86,8 @@ class EnvoyStartException(Exception):
class DirectorStartException(Exception):
"""Exception for director start"""
pass


class DataSetupException(Exception):
"""Exception for data setup for given model"""
pass
106 changes: 98 additions & 8 deletions tests/end_to_end/utils/federation_helper.py
@@ -7,8 +7,10 @@
import os
import json
import re
import subprocess # nosec B404
import papermill as pm
from pathlib import Path
import shutil

import tests.end_to_end.utils.constants as constants
import tests.end_to_end.utils.docker_helper as dh
@@ -110,18 +112,24 @@ def setup_pki_for_collaborators(collaborators, model_owner, local_bind_path):
return True


def create_tarball_for_collaborators(collaborators, local_bind_path, use_tls):
def create_tarball_for_collaborators(collaborators, local_bind_path, use_tls, add_data=False):
"""
Create tarball for all the collaborators
Args:
collaborators (list): List of collaborator objects
local_bind_path (str): Local bind path
use_tls (bool): Use TLS or not (default is True)
add_data (bool): Add data to the tarball (default is False)
"""
executor = concurrent.futures.ThreadPoolExecutor()
try:

def _create_tarball(collaborator_name, local_bind_path):
def _create_tarball(collaborator_name, data_file_path, local_bind_path, add_data):
"""
Internal function to create tarball for the collaborator.
If TLS is enabled - include client certificates and signed certificates in the tarball
If data needs to be added - include the data file in the tarball
"""
local_col_ws_path = constants.COL_WORKSPACE_PATH.format(
local_bind_path, collaborator_name
)
@@ -134,7 +142,11 @@ def _create_tarball(collaborator_name, local_bind_path):
]
client_certs = " ".join(client_cert_entries) if client_cert_entries else ""
tarfiles += f" agg_to_col_{collaborator_name}_signed_cert.zip {client_certs}"
# IMPORTANT: The XGBoost model (xgb_higgs) lays out data as data/1, data/2, ..., so the data folder is added to the tarball in the same format.
if add_data:
tarfiles += f" data/{data_file_path}"

log.info(f"Tarfile for {collaborator_name} includes: {tarfiles}")
return_code, output, error = ssh.run_command(
f"tar -cf {tarfiles}", work_dir=local_col_ws_path
)
@@ -146,9 +158,9 @@

results = [
executor.submit(
_create_tarball, collaborator.name, local_bind_path=local_bind_path
_create_tarball, collaborator.name, data_file_path=index, local_bind_path=local_bind_path, add_data=add_data
)
for collaborator in collaborators
for index, collaborator in enumerate(collaborators, start=1)
]
if not all([f.result() for f in results]):
raise Exception("Failed to create tarball for one or more collaborators")
@@ -629,18 +641,22 @@ def verify_cmd_output(
raise Exception(f"{error_msg}: {error}")


def setup_collaborator(count, workspace_path, local_bind_path):
def setup_collaborator(index, workspace_path, local_bind_path):
"""
Setup the collaborator
Includes - creation of collaborator objects, starting docker container, importing workspace, creating collaborator
Args:
index (int): Index of the collaborator. Starts with 1.
workspace_path (str): Workspace path
local_bind_path (str): Local bind path
"""
local_agg_ws_path = constants.AGG_WORKSPACE_PATH.format(local_bind_path)

try:
collaborator = col_model.Collaborator(
collaborator_name=f"collaborator{count+1}",
data_directory_path=count + 1,
workspace_path=f"{workspace_path}/collaborator{count+1}/workspace",
collaborator_name=f"collaborator{index}",
data_directory_path=index,
workspace_path=f"{workspace_path}/collaborator{index}/workspace",
)
create_persistent_store(collaborator.name, local_bind_path)

@@ -670,6 +686,80 @@ def setup_collaborator(count, workspace_path, local_bind_path):
return collaborator


def setup_collaborator_data(collaborators, model_name, local_bind_path):
"""
Function to setup the data for collaborators.
IMP: This function is specific to the model and should be updated as per the model requirements.
Args:
collaborators (list): List of collaborator objects
model_name (str): Model name
local_bind_path (str): Local bind path
"""
# Check if data already exists, if yes, skip the download part
# This is mainly helpful in case of re-runs
if all(os.path.exists(os.path.join(collaborator.workspace_path, "data", str(index))) for index, collaborator in enumerate(collaborators, start=1)):
log.info("Data already exists for all the collaborators. Skipping the download part..")
return
else:
log.info("Data does not exist for all the collaborators. Proceeding with the download..")
# Below step will also modify the data.yaml file for all the collaborators
download_data(collaborators, model_name, local_bind_path)

log.info("Data setup is complete for all the collaborators")


def download_data(collaborators, model_name, local_bind_path):
"""
Download the data for the model and copy to the respective collaborator workspaces
Also modify the data.yaml file for all the collaborators
Args:
collaborators (list): List of collaborator objects
model_name (str): Model name
local_bind_path (str): Local bind path
Returns:
bool: True if successful, else False
"""
log.info(f"Copying {constants.DATA_SETUP_FILE} from one of the collaborator workspaces to the local bind path..")
try:
shutil.copyfile(
src=os.path.join(collaborators[0].workspace_path, "src", constants.DATA_SETUP_FILE),
dst=os.path.join(local_bind_path, constants.DATA_SETUP_FILE)
)
except Exception as e:
raise ex.DataSetupException(f"Failed to copy data setup file: {e}")

log.info("Downloading the data for the model. This will take some time to complete based on the data size ..")
try:
command = ["python", constants.DATA_SETUP_FILE, str(len(collaborators))]
subprocess.run(command, cwd=local_bind_path, check=True)
except Exception:
raise ex.DataSetupException(f"Failed to download data for {model_name}")

try:
# Copy the data to the respective workspaces based on the index
for index, collaborator in enumerate(collaborators, start=1):
src_folder = os.path.join(local_bind_path, "data", str(index))
dst_folder = os.path.join(collaborator.workspace_path, "data", str(index))
if os.path.exists(src_folder):
shutil.copytree(src_folder, dst_folder, dirs_exist_ok=True)
log.info(f"Copied data from {src_folder} to {dst_folder}")
else:
raise ex.DataSetupException(f"Source folder {src_folder} does not exist for {collaborator.name}")

# Modify the data.yaml file for all the collaborators
collaborator.modify_data_file(
constants.COL_DATA_FILE.format(local_bind_path, collaborator.name),
index,
)
except Exception as e:
raise ex.DataSetupException(f"Failed to modify the data file: {e}")

# Below step is specific to XGBoost model which uses higgs_data folder to create data folders.
shutil.rmtree(os.path.join(local_bind_path, "higgs_data"), ignore_errors=True)

return True


def extract_memory_usage(log_file):
"""
Extracts memory usage data from a log file.
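Putting the new helpers together, the intended data-setup flow for xgb_higgs looks roughly like this sketch; `collaborators` and `local_bind_path` are assumed to come from the test fixtures:

import tests.end_to_end.utils.federation_helper as fh

# Download the HIGGS data once, copy data/1, data/2, ... into each collaborator
# workspace, and rewrite each plan/data.yaml via Collaborator.modify_data_file().
fh.setup_collaborator_data(collaborators, "xgb_higgs", local_bind_path)

# Include the per-collaborator data folder in the workspace tarball.
fh.create_tarball_for_collaborators(collaborators, local_bind_path, use_tls=True, add_data=True)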
