Skip to content

Commit

Permalink
Introduce callbacks API (securefederatedai#1195)
Browse files Browse the repository at this point in the history
* Get rid of kwargs

Signed-off-by: Shah, Karan <[email protected]>

* Use module-level logger

Signed-off-by: Shah, Karan <[email protected]>

* Reduce keras verbosity

Signed-off-by: Shah, Karan <[email protected]>

* Remove all log_metric and log_memory_usage traces; add callback hooks

Signed-off-by: Shah, Karan <[email protected]>

* Add `openfl.callbacks` module

Signed-off-by: Shah, Karan <[email protected]>

* Include round_num for task callbacks

Signed-off-by: Shah, Karan <[email protected]>

* Add tensordb to callbacks

Signed-off-by: Shah, Karan <[email protected]>

* No round_num on task callbacks

Signed-off-by: Shah, Karan <[email protected]>

* Remove task boundary callbacks

Signed-off-by: Shah, Karan <[email protected]>

* Remove tb/model_ckpt. Add memory_profiler

Signed-off-by: Shah, Karan <[email protected]>

* Restore psutil and tbX

Signed-off-by: Shah, Karan <[email protected]>

* Format code

Signed-off-by: Shah, Karan <[email protected]>

* Define default callbacks

Signed-off-by: Shah, Karan <[email protected]>

* Add write_logs for bwd compat

Signed-off-by: Shah, Karan <[email protected]>

* Add log_metric_callback for bwd compat

Signed-off-by: Shah, Karan <[email protected]>

* Migrate to module-level logger for collaborator

Signed-off-by: Shah, Karan <[email protected]>

* Review comments

Signed-off-by: Shah, Karan <[email protected]>

* Add metric_writer

Signed-off-by: Shah, Karan <[email protected]>

* Add collaborator side metric logging

Signed-off-by: Shah, Karan <[email protected]>

* Make log dirs on exp begin

Signed-off-by: Shah, Karan <[email protected]>

* Do not print use_tls

Signed-off-by: Shah, Karan <[email protected]>

* Assume reportable metric to be a scalar

Signed-off-by: Shah, Karan <[email protected]>

* Add aggregator side callbacks

Signed-off-by: Shah, Karan <[email protected]>

* do_task test returns mock dict

Signed-off-by: Shah, Karan <[email protected]>

* Consistency changes

Signed-off-by: Shah, Karan <[email protected]>

* Add documentation hooks

Signed-off-by: Shah, Karan <[email protected]>

* Update docstring

Signed-off-by: Shah, Karan <[email protected]>

* Update docs hook

Signed-off-by: Shah, Karan <[email protected]>

* Remove all traces of log_metric_callback and write_metric

Signed-off-by: Shah, Karan <[email protected]>

* Do on_round_begin if not time_to_quit

Signed-off-by: Shah, Karan <[email protected]>

---------

Signed-off-by: Shah, Karan <[email protected]>
  • Loading branch information
MasterSkepticista authored Dec 23, 2024
1 parent 4e021d2 commit c280f10
Show file tree
Hide file tree
Showing 37 changed files with 510 additions and 482 deletions.
16 changes: 16 additions & 0 deletions docs/openfl.callbacks.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
``openfl.callbacks`` module
===========================

.. currentmodule:: openfl.callbacks

.. automodule:: openfl.callbacks

.. autosummary::
:toctree: _autosummary
:recursive:

Callback
CallbackList
LambdaCallback
MetricWriter
MemoryProfiler
3 changes: 2 additions & 1 deletion docs/openfl.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.. currentmodule:: openfl

Public API: ``openfl`` package
API Reference: ``openfl``
===========================

Subpackages
Expand All @@ -10,6 +10,7 @@ Subpackages
:maxdepth: 1

openfl.component
openfl.callbacks
openfl.cryptography
openfl.experimental
openfl.databases
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ aggregator :
template : openfl.experimental.workflow.component.Aggregator
settings :
rounds_to_train : 1
log_metric_callback :
template : src.utils.write_metric


collaborator :
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ aggregator :
template : openfl.experimental.workflow.component.Aggregator
settings :
rounds_to_train : 1
log_metric_callback :
template : src.utils.write_metric


collaborator :
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ aggregator :
template : openfl.experimental.workflow.component.Aggregator
settings :
rounds_to_train : 1
log_metric_callback :
template : src.utils.write_metric


collaborator :
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ aggregator :
template : openfl.experimental.workflow.component.Aggregator
settings :
rounds_to_train : 1
log_metric_callback :
template : src.utils.write_metric


collaborator :
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ aggregator :
template : openfl.experimental.workflow.component.Aggregator
settings :
rounds_to_train : 1
log_metric_callback :
template : src.utils.write_metric


collaborator :
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ aggregator :
template : openfl.experimental.workflow.component.aggregator.Aggregator
settings :
rounds_to_train : 1
log_metric_callback :
template : src.utils.write_metric


collaborator :
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ aggregator :
template : openfl.experimental.workflow.component.aggregator.Aggregator
settings :
rounds_to_train : 10
log_metric_callback :
template : src.utils.write_metric


collaborator :
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ aggregator :
best_state_path : save/torch_cnn_mnist_best.pbuf
last_state_path : save/torch_cnn_mnist_last.pbuf
rounds_to_train : 10
log_metric_callback :
template : src.mnist_utils.write_metric


collaborator :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,11 @@
from logging import getLogger

import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision import transforms

logger = getLogger(__name__)

writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
get_writer()
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)


def one_hot(labels, classes):
"""
Expand Down
2 changes: 0 additions & 2 deletions openfl-workspace/torch_cnn_mnist_fed_eval/plan/plan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ aggregator :
init_state_path : save/torch_cnn_mnist_init.pbuf
best_state_path : save/torch_cnn_mnist_best.pbuf
last_state_path : save/torch_cnn_mnist_last.pbuf
log_metric_callback :
template : src.mnist_utils.write_metric

collaborator :
defaults : plan/defaults/collaborator.yaml
Expand Down
16 changes: 0 additions & 16 deletions openfl-workspace/torch_cnn_mnist_fed_eval/src/mnist_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,11 @@
from logging import getLogger

import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision import transforms

logger = getLogger(__name__)

writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
get_writer()
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)


def one_hot(labels, classes):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ aggregator :
best_state_path : save/torch_cnn_mnist_best.pbuf
last_state_path : save/torch_cnn_mnist_last.pbuf
rounds_to_train : 6
log_metric_callback :
template : src.mnist_utils.write_metric


collaborator :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,11 @@
from logging import getLogger

import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision import transforms

logger = getLogger(__name__)

writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
get_writer()
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)


def one_hot(labels, classes):
"""
Expand Down
2 changes: 0 additions & 2 deletions openfl-workspace/torch_llm_horovod/plan/plan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ aggregator :
best_state_path : save/torch_llm_best.pbuf
last_state_path : save/torch_llm_last.pbuf
rounds_to_train : 5
log_metric_callback :
template : src.emotion_utils.write_metric


collaborator :
Expand Down
16 changes: 0 additions & 16 deletions openfl-workspace/torch_llm_horovod/src/emotion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,10 @@
from logging import getLogger

from datasets import Dataset, load_dataset
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoTokenizer, DataCollatorWithPadding

logger = getLogger(__name__)

writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter("./logs/llm", flush_secs=5)


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
get_writer()
writer.add_scalar(f"{node_name}/{task_name}/{metric_name}", metric, round_number)


def get_emotion_dataset(tokenizer):
dataset = load_dataset("dair-ai/emotion", cache_dir="dataset", revision="9ce6303")
Expand Down
1 change: 0 additions & 1 deletion openfl-workspace/workspace/plan/defaults/aggregator.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
template : openfl.component.Aggregator
settings :
db_store_rounds : 2
write_logs : true
Loading

0 comments on commit c280f10

Please sign in to comment.