Skip to content

Commit

Permalink
Fixed issue with deep-r and epoch-by-epoch training
Browse files Browse the repository at this point in the history
* DeepRInit should trigger at beginning of first epoch, not at start of training (in fact, start of training is a silly time to start any kind of custom update)
* Refactor callbacks to take a filter function and replaced bespoke callbacks used in EventProp compiler
  • Loading branch information
neworderofjamie committed Sep 6, 2024
1 parent 9ac44cd commit ed29059
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 77 deletions.
8 changes: 3 additions & 5 deletions ml_genn/ml_genn/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from .conn_var_recorder import ConnVarRecorder
from .custom_update import (CustomUpdateOnBatchBegin, CustomUpdateOnBatchEnd,
CustomUpdateOnEpochBegin, CustomUpdateOnEpochEnd,
CustomUpdateOnTimestepBegin, CustomUpdateOnTimestepEnd,
CustomUpdateOnTrainBegin)
CustomUpdateOnTimestepBegin, CustomUpdateOnTimestepEnd)
from .optimiser_param_schedule import OptimiserParamSchedule
from .progress_bar import BatchProgressBar
from .spike_recorder import SpikeRecorder
Expand All @@ -18,6 +17,5 @@
__all__ = ["Callback", "Checkpoint", "ConnVarRecorder",
"CustomUpdateOnBatchBegin", "CustomUpdateOnBatchEnd",
"CustomUpdateOnEpochBegin", "CustomUpdateOnEpochEnd",
"CustomUpdateOnTimestepBegin", "CustomUpdateOnTimestepEnd",
"CustomUpdateOnTrainBegin", "OptimiserParamSchedule",
"BatchProgressBar", "SpikeRecorder", "VarRecorder",
"default_callbacks"]
"OptimiserParamSchedule", "BatchProgressBar", "SpikeRecorder",
"VarRecorder", "default_callbacks"]
61 changes: 32 additions & 29 deletions ml_genn/ml_genn/callbacks/custom_update.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

from typing import Callable
from .callback import Callback

logger = logging.getLogger(__name__)
class CustomUpdate(Callback):
    """Base class for callbacks that trigger a GeNN custom update.

    Args:
        name:       Name of custom update to trigger
        filter_fn:  Optional predicate receiving the current
                    epoch/batch/timestep number; the custom update is
                    only launched when it returns True. If None, the
                    update is launched unconditionally.
    """
    def __init__(self, name: str, filter_fn: Callable[[int], bool] = None):
        self.name = name
        self.filter_fn = filter_fn

    def set_params(self, compiled_network, **kwargs):
        # Extract compiled network so _custom_update can reach the GeNN model
        self._compiled_network = compiled_network

    def _custom_update(self, number):
        # If there is no filter or the filter returns True, launch the
        # custom update and report to the caller that it actually ran
        if self.filter_fn is None or self.filter_fn(number):
            self._compiled_network.genn_model.custom_update(self.name)
            return True
        else:
            return False


class CustomUpdateOnBatchBegin(CustomUpdate):
    """Callback that triggers a GeNN custom update
    at the beginning of every batch."""
    def on_batch_begin(self, batch):
        # Only log when the (possibly filtered) update actually ran
        if self._custom_update(batch):
            logger.debug(f"Running custom update {self.name} "
                         f"at start of batch {batch}")


class CustomUpdateOnBatchEnd(CustomUpdate):
    """Callback that triggers a GeNN custom update
    at the end of every batch."""
    def on_batch_end(self, batch, metrics):
        # Only log when the (possibly filtered) update actually ran
        if self._custom_update(batch):
            logger.debug(f"Running custom update {self.name} "
                         f"at end of batch {batch}")

class CustomUpdateOnEpochBegin(CustomUpdate):
    """Callback that triggers a GeNN custom update
    at the beginning of every epoch."""
    def on_epoch_begin(self, epoch):
        # Only log when the (possibly filtered) update actually ran
        if self._custom_update(epoch):
            logger.debug(f"Running custom update {self.name} "
                         f"at start of epoch {epoch}")


class CustomUpdateOnEpochEnd(CustomUpdate):
    """Callback that triggers a GeNN custom update
    at the end of every epoch."""
    def on_epoch_end(self, epoch, metrics):
        # Only log when the (possibly filtered) update actually ran
        if self._custom_update(epoch):
            logger.debug(f"Running custom update {self.name} "
                         f"at end of epoch {epoch}")


class CustomUpdateOnTimestepBegin(CustomUpdate):
    """Callback that triggers a GeNN custom update
    at the beginning of every timestep."""
    def on_timestep_begin(self, timestep):
        # Only log when the (possibly filtered) update actually ran
        if self._custom_update(timestep):
            logger.debug(f"Running custom update {self.name} "
                         f"at start of timestep {timestep}")



class CustomUpdateOnTimestepEnd(CustomUpdate):
    """Callback that triggers a GeNN custom update
    at the end of every timestep."""
    def on_timestep_end(self, timestep):
        # Only log when the (possibly filtered) update actually ran.
        # NOTE: message previously said "at start of timestep" — copy-paste
        # error from CustomUpdateOnTimestepBegin; this runs at the end.
        if self._custom_update(timestep):
            logger.debug(f"Running custom update {self.name} "
                         f"at end of timestep {timestep}")
6 changes: 3 additions & 3 deletions ml_genn/ml_genn/compilers/eprop_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from .deep_r import RewiringRecord
from .. import Connection, Population, Network
from ..callbacks import (BatchProgressBar, CustomUpdateOnBatchBegin,
CustomUpdateOnBatchEnd, CustomUpdateOnTimestepEnd,
CustomUpdateOnTrainBegin)
CustomUpdateOnBatchEnd, CustomUpdateOnEpochBegin,
CustomUpdateOnTimestepEnd)
from ..communicators import Communicator
from ..losses import Loss, SparseCategoricalCrossentropy
from ..neurons import (AdaptiveLeakyIntegrateFire, Input,
Expand Down Expand Up @@ -608,7 +607,8 @@ def create_compiled_network(self, genn_model, neuron_populations: dict,

# If Deep-R is required, trigger Deep-R callbacks at end of batch
if deep_r_required:
base_train_callbacks.append(CustomUpdateOnTrainBegin("DeepRInit"))
base_train_callbacks.append(CustomUpdateOnEpochBegin("DeepRInit",
lambda e: e == 0))
base_train_callbacks.append(CustomUpdateOnBatchEnd("DeepR1"))
base_train_callbacks.append(CustomUpdateOnBatchEnd("DeepR2"))

Expand Down
54 changes: 14 additions & 40 deletions ml_genn/ml_genn/compilers/event_prop_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .. import Connection, Population, Network
from ..callbacks import (BatchProgressBar, Callback, CustomUpdateOnBatchBegin,
CustomUpdateOnBatchEnd, CustomUpdateOnEpochBegin,
CustomUpdateOnEpochEnd, CustomUpdateOnTimestepBegin,
CustomUpdateOnTimestepEnd)
from ..communicators import Communicator
from ..connection import Connection
Expand Down Expand Up @@ -211,40 +212,6 @@ def on_batch_begin(self, batch: int):
self.genn_pop.set_dynamic_param_value("Trial", batch)


class CustomUpdateOnLastTimestep(Callback):
    """Callback that triggers a GeNN custom update
    at the start of the last timestep in each example"""
    def __init__(self, name: str, example_timesteps: int):
        # name: name of custom update to trigger
        self.name = name
        # example_timesteps: total timesteps per example; used to detect
        # the final timestep below
        self.example_timesteps = example_timesteps

    def set_params(self, compiled_network, **kwargs):
        # Extract compiled network
        self._compiled_network = compiled_network

    def on_timestep_begin(self, timestep: int):
        # Timesteps are zero-based, so the last one is example_timesteps - 1
        if timestep == (self.example_timesteps - 1):
            logger.debug(f"Running custom update {self.name} "
                         f"at start of timestep {timestep}")
            self._compiled_network.genn_model.custom_update(self.name)

class CustomUpdateOnBatchEndNotFirst(Callback):
    """Callback that triggers a GeNN custom update
    at the end of every batch after the first."""
    def __init__(self, name: str):
        # name: name of custom update to trigger
        self.name = name

    def set_params(self, compiled_network, **kwargs):
        # Extract compiled network
        self._compiled_network = compiled_network

    def on_batch_end(self, batch, metrics):
        # Skip batch 0 — only trigger from the second batch onwards
        if batch > 0:
            logger.debug(f"Running custom update {self.name} "
                         f"at end of batch {batch}")
            self._compiled_network.genn_model.custom_update(self.name)

# Standard EventProp weight update model
# **NOTE** feedback is added if required
weight_update_model = {
Expand Down Expand Up @@ -1185,14 +1152,17 @@ def create_compiled_network(self, genn_model, neuron_populations: dict,

# If Deep-R and L1 regularisation are required, add callback
if deep_r_required and self.deep_r_l1_strength > 0.0:
base_train_callbacks.append(CustomUpdateOnBatchEnd("DeepRL1"))
base_train_callbacks.append(
CustomUpdateOnBatchEnd("DeepRL1", lambda batch: batch > 0))

if len(weight_optimiser_cus) > 0 or len(delay_optimiser_cus) > 0:
if self.full_batch_size > 1:
base_train_callbacks.append(
CustomUpdateOnBatchEndNotFirst("GradientBatchReduce"))
CustomUpdateOnBatchEnd("GradientBatchReduce",
lambda batch: batch > 0))
base_train_callbacks.append(
CustomUpdateOnBatchEndNotFirst("GradientLearn"))
CustomUpdateOnBatchEnd("GradientLearn",
lambda batch: batch > 0))
base_validate_callbacks.append(
CustomUpdateOnEpochEnd("ZeroGradient"))

Expand All @@ -1202,10 +1172,13 @@ def create_compiled_network(self, genn_model, neuron_populations: dict,
base_train_callbacks.append(UpdateTrial(neuron_populations[p]))

# Add callbacks to zero out post on all connections
last_timestep = self.example_timesteps - 1
base_train_callbacks.append(
CustomUpdateOnLastTimestep("ZeroOutPost", self.example_timesteps))
CustomUpdateOnTimestepBegin("ZeroOutPost",
lambda t: t == last_timestep))
base_validate_callbacks.append(
CustomUpdateOnLastTimestep("ZeroOutPost", self.example_timesteps))
CustomUpdateOnTimestepBegin("ZeroOutPost",
lambda t: t == last_timestep))

# If softmax calculation is required at end of batch, add callbacks
if len(compile_state.batch_softmax_populations) > 0:
Expand All @@ -1226,7 +1199,8 @@ def create_compiled_network(self, genn_model, neuron_populations: dict,

# If Deep-R is required, trigger Deep-R callbacks at end of batch
if deep_r_required:
base_train_callbacks.append(CustomUpdateOnTrainBegin("DeepRInit"))
base_train_callbacks.append(CustomUpdateOnEpochBegin("DeepRInit",
lambda e: e == 0))
base_train_callbacks.append(CustomUpdateOnBatchEnd("DeepR1"))
base_train_callbacks.append(CustomUpdateOnBatchEnd("DeepR2"))

Expand Down

0 comments on commit ed29059

Please sign in to comment.