Skip to content

Commit

Permalink
Merge branch 'develop' into torch-update
Browse files Browse the repository at this point in the history
  • Loading branch information
tanwarsh authored Jan 23, 2025
2 parents a414b22 + 22e7dab commit 4fbecc0
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 30 deletions.
136 changes: 113 additions & 23 deletions .github/workflows/task_runner_basic_e2e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,116 @@ on:
required: false
default: "2"
type: string
model_name:
description: "Model name"
required: false
default: "all"
type: choice
options:
- all
- torch_cnn_mnist
- keras_cnn_mnist
python_version:
description: "Python version"
required: false
default: "all"
type: choice
options:
- all
- "3.10"
- "3.11"
- "3.12"
jobs_to_run:
description: "Jobs to run"
type: choice
default: "all"
options:
- all
- test_with_tls
- test_with_non_tls
- test_with_no_client_auth
- test_memory_logs
required: false

permissions:
contents: read

# Environment variables common for all the jobs
# DO NOT use double quotes for the values of the environment variables
env:
NUM_ROUNDS: ${{ inputs.num_rounds || '5' }}
NUM_COLLABORATORS: ${{ inputs.num_collaborators || '2' }}
NUM_ROUNDS: ${{ inputs.num_rounds || 5 }}
NUM_COLLABORATORS: ${{ inputs.num_collaborators || 2 }}
MODEL_NAME: ${{ inputs.model_name || 'all' }}
PYTHON_VERSION: ${{ inputs.python_version || 'all' }}
JOBS_TO_RUN: ${{ inputs.jobs_to_run || 'all' }}

jobs:
input_selection:
  # Runs for scheduled builds of the upstream repository and for manual dispatches.
  if: |
    (github.event_name == 'schedule' && github.repository_owner == 'securefederatedai') ||
    (github.event_name == 'workflow_dispatch')
  name: Input value selection
  runs-on: ubuntu-22.04
  outputs:
    # Output all the variables related to models and python versions to be used in the matrix strategy
    # for different jobs, however their usage depends on the selected job.
    selected_jobs: ${{ steps.input_selection.outputs.jobs_to_run }}
    selected_models_for_tls: ${{ steps.input_selection.outputs.models_for_tls }}
    selected_python_for_tls: ${{ steps.input_selection.outputs.python_for_tls }}
    selected_models_for_non_tls: ${{ steps.input_selection.outputs.models_for_non_tls }}
    selected_models_for_no_client_auth: ${{ steps.input_selection.outputs.models_for_no_client_auth }}
    selected_models_for_memory_logs: ${{ steps.input_selection.outputs.models_for_memory_logs }}
    selected_python_for_non_tls: ${{ steps.input_selection.outputs.python_for_non_tls }}
    selected_python_for_no_client_auth: ${{ steps.input_selection.outputs.python_for_no_client_auth }}
    selected_python_for_memory_logs: ${{ steps.input_selection.outputs.python_for_memory_logs }}
  steps:
    - name: Job to select input values
      id: input_selection
      run: |
        # ---------------------------------------------------------------
        # Models like XGBoost (xgb_higgs) and torch_cnn_histology require runners with higher memory and CPU to run.
        # Thus these models are excluded from the matrix for now.
        # Default combination if no input is provided (i.e. 'all' is selected).
        # * TLS - models [torch_cnn_mnist, keras_cnn_mnist] and python versions [3.10, 3.11, 3.12]
        # * Non-TLS - models [torch_cnn_mnist] and python version [3.10]
        # * No client auth - models [keras_cnn_mnist] and python version [3.10]
        # * Memory logs - models [torch_cnn_mnist] and python version [3.10]
        #
        # JOBS_TO_RUN, MODEL_NAME and PYTHON_VERSION are read as shell variables
        # (exported by the workflow-level env block) instead of being spliced into
        # the script text through workflow expressions, so an input value can never
        # be interpreted as shell code.
        # ---------------------------------------------------------------
        echo "jobs_to_run=${JOBS_TO_RUN}" >> "$GITHUB_OUTPUT"

        # Each output is a JSON array consumed with fromJson() in a job matrix.
        if [ "${MODEL_NAME}" = "all" ]; then
          echo 'models_for_tls=["torch_cnn_mnist", "keras_cnn_mnist"]' >> "$GITHUB_OUTPUT"
          echo 'models_for_non_tls=["torch_cnn_mnist"]' >> "$GITHUB_OUTPUT"
          echo 'models_for_no_client_auth=["keras_cnn_mnist"]' >> "$GITHUB_OUTPUT"
          echo 'models_for_memory_logs=["torch_cnn_mnist"]' >> "$GITHUB_OUTPUT"
        else
          echo "models_for_tls=[\"${MODEL_NAME}\"]" >> "$GITHUB_OUTPUT"
          echo "models_for_non_tls=[\"${MODEL_NAME}\"]" >> "$GITHUB_OUTPUT"
          echo "models_for_no_client_auth=[\"${MODEL_NAME}\"]" >> "$GITHUB_OUTPUT"
          echo "models_for_memory_logs=[\"${MODEL_NAME}\"]" >> "$GITHUB_OUTPUT"
        fi

        if [ "${PYTHON_VERSION}" = "all" ]; then
          echo 'python_for_tls=["3.10", "3.11", "3.12"]' >> "$GITHUB_OUTPUT"
          echo 'python_for_non_tls=["3.10"]' >> "$GITHUB_OUTPUT"
          echo 'python_for_no_client_auth=["3.10"]' >> "$GITHUB_OUTPUT"
          echo 'python_for_memory_logs=["3.10"]' >> "$GITHUB_OUTPUT"
        else
          echo "python_for_tls=[\"${PYTHON_VERSION}\"]" >> "$GITHUB_OUTPUT"
          echo "python_for_non_tls=[\"${PYTHON_VERSION}\"]" >> "$GITHUB_OUTPUT"
          echo "python_for_no_client_auth=[\"${PYTHON_VERSION}\"]" >> "$GITHUB_OUTPUT"
          echo "python_for_memory_logs=[\"${PYTHON_VERSION}\"]" >> "$GITHUB_OUTPUT"
        fi
test_with_tls:
name: tr_tls
name: Test with TLS
needs: input_selection
if: needs.input_selection.outputs.selected_jobs == 'all' || needs.input_selection.outputs.selected_jobs == 'test_with_tls'
runs-on: ubuntu-22.04
timeout-minutes: 30
strategy:
matrix:
# Models like XGBoost (xgb_higgs) and torch_cnn_histology require runners with higher memory and CPU to run.
# Thus these models are excluded from the matrix for now.
model_name: ["torch_cnn_mnist", "keras_cnn_mnist"]
python_version: ["3.10", "3.11", "3.12"]
model_name: ${{ fromJson(needs.input_selection.outputs.selected_models_for_tls) }}
python_version: ${{ fromJson(needs.input_selection.outputs.selected_python_for_tls) }}
fail-fast: false # do not immediately fail if one of the combinations fail

env:
Expand Down Expand Up @@ -72,15 +162,15 @@ jobs:
test_type: "tr_tls"

test_with_non_tls:
name: tr_non_tls
name: Test without TLS
needs: input_selection
if: needs.input_selection.outputs.selected_jobs == 'all' || needs.input_selection.outputs.selected_jobs == 'test_with_non_tls'
runs-on: ubuntu-22.04
timeout-minutes: 30
strategy:
matrix:
# Testing this scenario only for torch_cnn_mnist model and python 3.10
# If required, this can be extended to other models and python versions
model_name: ["torch_cnn_mnist"]
python_version: ["3.10"]
model_name: ${{ fromJson(needs.input_selection.outputs.selected_models_for_non_tls) }}
python_version: ${{ fromJson(needs.input_selection.outputs.selected_python_for_non_tls) }}
fail-fast: false # do not immediately fail if one of the combinations fail

env:
Expand Down Expand Up @@ -115,15 +205,15 @@ jobs:
test_type: "tr_non_tls"

test_with_no_client_auth:
name: tr_no_client_auth
name: Test without client auth
needs: input_selection
if: needs.input_selection.outputs.selected_jobs == 'all' || needs.input_selection.outputs.selected_jobs == 'test_with_no_client_auth'
runs-on: ubuntu-22.04
timeout-minutes: 30
strategy:
matrix:
# Testing this scenario for keras_cnn_mnist model and python 3.10
# If required, this can be extended to other models and python versions
model_name: ["keras_cnn_mnist"]
python_version: ["3.10"]
model_name: ${{ fromJson(needs.input_selection.outputs.selected_models_for_no_client_auth) }}
python_version: ${{ fromJson(needs.input_selection.outputs.selected_python_for_no_client_auth) }}
fail-fast: false # do not immediately fail if one of the combinations fail

env:
Expand Down Expand Up @@ -155,18 +245,18 @@ jobs:
uses: ./.github/actions/tr_post_test_run
if: ${{ always() }}
with:
test_type: "tr_no_client_auth"
test_type: 'tr_no_client_auth'

test_memory_logs:
name: tr_tls_memory_logs
name: Test memory usage
needs: input_selection
if: needs.input_selection.outputs.selected_jobs == 'all' || needs.input_selection.outputs.selected_jobs == 'test_memory_logs'
runs-on: ubuntu-22.04
timeout-minutes: 30
strategy:
matrix:
# Testing this scenario only for torch_cnn_mnist model and python 3.10
# If required, this can be extended to other models and python versions
model_name: ["torch_cnn_mnist"]
python_version: ["3.10"]
model_name: ${{ fromJson(needs.input_selection.outputs.selected_models_for_memory_logs) }}
python_version: ${{ fromJson(needs.input_selection.outputs.selected_python_for_memory_logs) }}
fail-fast: false # do not immediately fail if one of the combinations fail

env:
Expand Down
39 changes: 32 additions & 7 deletions tests/end_to_end/utils/summary_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,13 @@ def get_aggregated_accuracy(agg_log_file):
return agg_accuracy

agg_accuracy_dict = convert_to_json(agg_log_file)
agg_accuracy = agg_accuracy_dict[-1].get(
"aggregator/aggregated_model_validation/accuracy", "Not Found"
)

if not agg_accuracy_dict:
print(f"Aggregator log file {agg_log_file} is empty. Cannot get aggregated accuracy, returning 'Not Found'")
else:
agg_accuracy = agg_accuracy_dict[-1].get(
"aggregator/aggregated_model_validation/accuracy", "Not Found"
)
return agg_accuracy


Expand Down Expand Up @@ -104,7 +108,7 @@ def get_testcase_result():

def print_task_runner_score():
"""
Main function to get the test case results and aggregator logs
Function to get the test case results and aggregator logs
And write the results to GitHub step summary
IMP: Do not fail the test in any scenario
"""
Expand All @@ -129,7 +133,7 @@ def print_task_runner_score():
num_cols = os.getenv("NUM_COLLABORATORS")
num_rounds = os.getenv("NUM_ROUNDS")
model_name = os.getenv("MODEL_NAME")
summary_file = os.getenv("GITHUB_STEP_SUMMARY")
summary_file = _get_summary_file()

# Validate the model name and create the workspace name
if not model_name.upper() in constants.ModelName._member_names_:
Expand Down Expand Up @@ -169,8 +173,12 @@ def print_task_runner_score():


def print_federated_runtime_score():
summary_file = os.getenv("GITHUB_STEP_SUMMARY")

"""
Function to get the federated runtime score from the director log file
And write the results to GitHub step summary
IMP: Do not fail the test in any scenario
"""
summary_file = _get_summary_file()
search_string = "Aggregated model validation score"

last_occurrence = aggregated_model_score = None
Expand Down Expand Up @@ -210,6 +218,23 @@ def print_federated_runtime_score():
print(f"| {aggregated_model_score} |", file=fh)


def _get_summary_file():
"""
Function to get the summary file path
Returns:
summary_file: Path to the summary file
"""
summary_file = os.getenv("GITHUB_STEP_SUMMARY")
print(f"Summary file: {summary_file}")

# Check if the fetched summary file is valid
if summary_file and os.path.isfile(summary_file):
return summary_file
else:
print("Invalid summary file. Exiting...")
exit(1)


def fetch_args():
"""
Function to fetch the commandline arguments.
Expand Down

0 comments on commit 4fbecc0

Please sign in to comment.