Skip to content

Commit

Permalink
Merge pull request #581 from genn-team/pygenn_unload
Browse files Browse the repository at this point in the history
Ability to unload PyGeNN models
  • Loading branch information
neworderofjamie authored May 15, 2023
2 parents ef69e59 + eaa29db commit cf51800
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 11 deletions.
60 changes: 59 additions & 1 deletion pygenn/genn_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,25 @@ def _load_var_init_egps(self, var_dict=None):
for var_name, var_data in iteritems(var_dict):
self._load_egp(var_data.extra_global_params, var_name)

def _unload_vars(self, var_dict=None):
# If no variable dictionary is specified, use standard one
if var_dict is None:
var_dict = self.vars

# Loop through variables and clear views
for v in itervalues(var_dict):
v.view = None
for e in itervalues(v.extra_global_params):
e.view = None

def _unload_egps(self, egp_dict=None):
# If no EGP dictionary is specified, use standard one
if egp_dict is None:
egp_dict = self.extra_global_params

# Loop through extra global params and clear views
for e in itervalues(egp_dict):
e.view = None

class NeuronGroup(Group):

Expand All @@ -339,12 +358,14 @@ def __init__(self, name, model):
self.spike_count = None
self.spike_events = None
self.spike_event_count = None
self.spike_que_ptr = [0]
self.spike_que_ptr = None
self._max_delay_steps = 0
self.spike_times = None
self.prev_spike_times = None
self.spike_event_times = None
self.prev_spike_event_times = None
self._spike_recording_data = None
self._spike_event_recording_data = None

@property
def current_spikes(self):
Expand Down Expand Up @@ -591,6 +612,22 @@ def load(self, num_recording_timesteps):
# Load neuron extra global params
self._load_egp()

def unload(self):
self.spikes = None
self.spike_count = None
self.spike_events = None
self.spike_event_count = None
self.spike_que_ptr = None
self.spike_times = None
self.prev_spike_times = None
self.spike_event_times = None
self.prev_spike_event_times = None
self._spike_recording_data = None
self._spike_event_recording_data = None

self._unload_vars()
self._unload_egps()

def load_init_egps(self):
# Load any egps used for variable initialisation
self._load_var_init_egps()
Expand Down Expand Up @@ -1293,6 +1330,19 @@ def load_init_egps(self):
self._load_var_init_egps(self.pre_vars)
self._load_var_init_egps(self.post_vars)

def unload(self):
self._ind = None
self._row_lengths = None
self.in_syn = None

self._unload_vars()
self._unload_vars(self.pre_vars)
self._unload_vars(self.post_vars)
self._unload_vars(self.psm_vars)
self._unload_egps()
self._unload_egps(self.psm_extra_global_params)
self._unload_egps(self.connectivity_extra_global_params)

def reinitialise(self):
"""Reinitialise synapse group"""
# If population has individual synapse variables
Expand Down Expand Up @@ -1416,6 +1466,10 @@ def load_init_egps(self):
# Load any egps used for variable initialisation
self._load_var_init_egps()

def unload(self):
self._unload_vars()
self._unload_egps()

def reinitialise(self):
"""Reinitialise current source"""
# Reinitialise current source state variables
Expand Down Expand Up @@ -1547,6 +1601,10 @@ def load_init_egps(self):
# Load any egps used for variable initialisation
self._load_var_init_egps()

def unload(self):
self._unload_vars()
self._unload_egps()

def reinitialise(self):
"""Reinitialise custom update"""
# If this is a custom weight update
Expand Down
25 changes: 24 additions & 1 deletion pygenn/genn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ def load(self, path_to_model="./", num_recording_timesteps=None):
# Allocate recording buffers
self._slm.allocate_recording_buffers(num_recording_timesteps)

# Loop through synapse populations and load any
# Loop through neuron populations and load any
# extra global parameters required for initialization
for pop_data in itervalues(self.neuron_populations):
pop_data.load_init_egps()
Expand Down Expand Up @@ -690,6 +690,29 @@ def load(self, path_to_model="./", num_recording_timesteps=None):
self._loaded = True
self._built = True

def unload(self):
# Loop through custom updates and unload
for cu_data in itervalues(self.custom_updates):
cu_data.unload()

# Loop through current sources and unload
for src_data in itervalues(self.current_sources):
src_data.unload()

# Loop through synapse populations and unload
for pop_data in itervalues(self.synapse_populations):
pop_data.unload()

# Loop through neuron populations and unload
for pop_data in itervalues(self.neuron_populations):
pop_data.unload()

# Close shared library model
self._slm.close()

# Clear loaded flag
self._loaded = False

def reinitialise(self):
"""reinitialise model to its original state without re-loading"""
if not self._loaded:
Expand Down
45 changes: 36 additions & 9 deletions userproject/include/sharedLibraryModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,8 @@ class SharedLibraryModel

virtual ~SharedLibraryModel()
{
// Close model library if loaded successfully
if(m_Library) {
freeMem();
#ifdef _WIN32
FreeLibrary(m_Library);
#else
dlclose(m_Library);
#endif
}
// Close model library
close();
}

//----------------------------------------------------------------------------
Expand Down Expand Up @@ -112,6 +105,40 @@ class SharedLibraryModel
}
}

void close()
{
if(m_Library) {
freeMem();
#ifdef _WIN32
FreeLibrary(m_Library);
#else
dlclose(m_Library);
#endif
m_Library = nullptr;
}

// Null all pointers
m_AllocateMem = nullptr;
m_AllocateRecordingBuffers = nullptr;
m_FreeMem = nullptr;
m_GetFreeDeviceMemBytes = nullptr;
m_Initialize = nullptr;
m_InitializeSparse = nullptr;
m_StepTime = nullptr;
m_PullRecordingBuffersFromDevice = nullptr;
m_NCCLGenerateUniqueID = nullptr;
m_NCCLGetUniqueID = nullptr;
m_NCCLInitCommunicator = nullptr;
m_NCCLUniqueIDBytes = nullptr;
m_T = nullptr;
m_Timestep = nullptr;

// Empty all dictionaries
m_PopulationVars.clear();
m_PopulationEPGs.clear();
m_CustomUpdates.clear();
}

void allocateExtraGlobalParam(const std::string &popName, const std::string &egpName, unsigned int count)
{
// Get EGP functions and check allocate exists
Expand Down

0 comments on commit cf51800

Please sign in to comment.