-
Notifications
You must be signed in to change notification settings - Fork 215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixes #718 WandB in OpenFL and example to change number of epochs per round #895
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -18,6 +18,8 @@ | |||||
from openfl.utilities import TensorKey | ||||||
from openfl.utilities.logs import write_metric | ||||||
|
||||||
import wandb | ||||||
LOG_WANDB = True | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should be configurable by whoever defines the experiment, i.e., it should be passed in through the OpenFL Plan.
||||||
|
||||||
class Aggregator: | ||||||
r"""An Aggregator is the central node in federated learning. | ||||||
|
@@ -55,6 +57,17 @@ def __init__(self, | |||||
log_metric_callback=None, | ||||||
**kwargs): | ||||||
"""Initialize.""" | ||||||
#INITIALIZE WANDB WITH CORRECT NAME | ||||||
#The following variable, my_aggregator_name, can be changed to whatever you want. Right now I suppose that the name of the collaborators of the federation is in the format DATASETNAME_ENV_NUMBER | ||||||
my_aggregator_name = '_'.join(set(element.split('_')[0] for element in authorized_cols)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should be moved into the
Suggested change
|
||||||
if LOG_WANDB: | ||||||
wandb.init(project="my_project", entity="my_group", group=f"{my_aggregator_name}", tags=["my_tag"], | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The names of the project, entity, group, etc. should also be configurable through the plan.
||||||
config={ | ||||||
"num_clients": 4, | ||||||
"rounds": 100 | ||||||
Comment on lines
+66
to
+67
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not familiar with the wandb configuration, but the same guidance applies here as in the earlier comments. I would expect
||||||
}, | ||||||
name=f"Aggregator_{my_aggregator_name}" | ||||||
) | ||||||
self.round_number = 0 | ||||||
self.single_col_cert_common_name = single_col_cert_common_name | ||||||
|
||||||
|
@@ -837,6 +850,8 @@ def _compute_validation_related_task_metrics(self, task_name): | |||||
if agg_function: | ||||||
self.logger.metric(f'Round {round_number}, aggregator: {task_name} ' | ||||||
f'{agg_function} {agg_tensor_name}:\t{agg_results:f}') | ||||||
if LOG_WANDB: | ||||||
wandb.log({f"{task_name} {agg_tensor_name}": float(f"{agg_results}")}, step=round_number) | ||||||
else: | ||||||
self.logger.metric(f'Round {round_number}, aggregator: {task_name} ' | ||||||
f'{agg_tensor_name}:\t{agg_results:f}') | ||||||
|
@@ -893,6 +908,8 @@ def _end_of_round_check(self): | |||||
# TODO This needs to be fixed! | ||||||
if self._time_to_quit(): | ||||||
self.logger.info('Experiment Completed. Cleaning up...') | ||||||
if LOG_WANDB: | ||||||
wandb.finish() | ||||||
else: | ||||||
self.logger.info(f'Starting round {self.round_number}...') | ||||||
|
||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,8 @@ | |
from openfl.protocols import utils | ||
from openfl.utilities import TensorKey | ||
|
||
import wandb | ||
LOG_WANDB = True | ||
Comment on lines
+17
to
+18
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same comment as above. This should only be imported if
||
|
||
class DevicePolicy(Enum): | ||
"""Device assignment policy.""" | ||
|
@@ -133,6 +135,14 @@ def set_available_devices(self, cuda: Tuple[str] = ()): | |
|
||
def run(self): | ||
"""Run the collaborator.""" | ||
if LOG_WANDB: | ||
wandb.init(project="my_project", entity="my_group", tags=["my_tags"], | ||
config={ | ||
"num_clients": 4, | ||
"rounds": 100, | ||
}, | ||
name=self.collaborator_name | ||
Comment on lines
+139
to
+144
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same comment as above. Make the project, entity, and tags configurable. Take the client count and number of rounds from the existing plan variables.
||
) | ||
while True: | ||
tasks, round_number, sleep_time, time_to_quit = self.get_tasks() | ||
if time_to_quit: | ||
|
@@ -148,6 +158,8 @@ def run(self): | |
self.tensor_db.clean_up(self.db_store_rounds) | ||
|
||
self.logger.info('End of Federation reached. Exiting...') | ||
if LOG_WANDB: | ||
wandb.finish() | ||
|
||
def run_simulation(self): | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would require `wandb` as an import for OpenFL, which is something we want to avoid.