From eaa29db7b8a72f3ca263951f1c82df72c50365cc Mon Sep 17 00:00:00 2001 From: neworderofjamie Date: Mon, 24 Apr 2023 18:52:19 +0100 Subject: [PATCH] ``GeNNModel.unload`` should clear out all simulation state from ``GeNNModel`` and unload shared library model --- pygenn/genn_groups.py | 60 +++++++++++++++++++++++- pygenn/genn_model.py | 25 +++++++++- userproject/include/sharedLibraryModel.h | 45 ++++++++++++++---- 3 files changed, 119 insertions(+), 11 deletions(-) diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index 41e2fd86ca..a9a517d210 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -321,6 +321,25 @@ def _load_var_init_egps(self, var_dict=None): for var_name, var_data in iteritems(var_dict): self._load_egp(var_data.extra_global_params, var_name) + def _unload_vars(self, var_dict=None): + # If no variable dictionary is specified, use standard one + if var_dict is None: + var_dict = self.vars + + # Loop through variables and clear views + for v in itervalues(var_dict): + v.view = None + for e in itervalues(v.extra_global_params): + e.view = None + + def _unload_egps(self, egp_dict=None): + # If no EGP dictionary is specified, use standard one + if egp_dict is None: + egp_dict = self.extra_global_params + + # Loop through extra global params and clear views + for e in itervalues(egp_dict): + e.view = None class NeuronGroup(Group): @@ -339,12 +358,14 @@ def __init__(self, name, model): self.spike_count = None self.spike_events = None self.spike_event_count = None - self.spike_que_ptr = [0] + self.spike_que_ptr = None self._max_delay_steps = 0 self.spike_times = None self.prev_spike_times = None self.spike_event_times = None self.prev_spike_event_times = None + self._spike_recording_data = None + self._spike_event_recording_data = None @property def current_spikes(self): @@ -591,6 +612,22 @@ def load(self, num_recording_timesteps): # Load neuron extra global params self._load_egp() + def unload(self): + self.spikes = None + self.spike_count = None + self.spike_events = None + self.spike_event_count = None + self.spike_que_ptr = None + self.spike_times = None + self.prev_spike_times = None + self.spike_event_times = None + self.prev_spike_event_times = None + self._spike_recording_data = None + self._spike_event_recording_data = None + + self._unload_vars() + self._unload_egps() + def load_init_egps(self): # Load any egps used for variable initialisation self._load_var_init_egps() @@ -1293,6 +1330,19 @@ def load_init_egps(self): self._load_var_init_egps(self.pre_vars) self._load_var_init_egps(self.post_vars) + def unload(self): + self._ind = None + self._row_lengths = None + self.in_syn = None + + self._unload_vars() + self._unload_vars(self.pre_vars) + self._unload_vars(self.post_vars) + self._unload_vars(self.psm_vars) + self._unload_egps() + self._unload_egps(self.psm_extra_global_params) + self._unload_egps(self.connectivity_extra_global_params) + def reinitialise(self): """Reinitialise synapse group""" # If population has individual synapse variables @@ -1416,6 +1466,10 @@ def load_init_egps(self): # Load any egps used for variable initialisation self._load_var_init_egps() + def unload(self): + self._unload_vars() + self._unload_egps() + def reinitialise(self): """Reinitialise current source""" # Reinitialise current source state variables @@ -1547,6 +1601,10 @@ def load_init_egps(self): # Load any egps used for variable initialisation self._load_var_init_egps() + def unload(self): + self._unload_vars() + self._unload_egps() + def reinitialise(self): """Reinitialise custom update""" # If this is a custom weight update diff --git a/pygenn/genn_model.py b/pygenn/genn_model.py index d50792f3e8..39637fbefd 100644 --- a/pygenn/genn_model.py +++ b/pygenn/genn_model.py @@ -646,7 +646,7 @@ def load(self, path_to_model="./", num_recording_timesteps=None): # Allocate recording buffers self._slm.allocate_recording_buffers(num_recording_timesteps) - # Loop through synapse populations and load any + # Loop through neuron populations and load any # extra global parameters required for initialization for pop_data in itervalues(self.neuron_populations): pop_data.load_init_egps() @@ -690,6 +690,29 @@ def load(self, path_to_model="./", num_recording_timesteps=None): self._loaded = True self._built = True + def unload(self): + # Loop through custom updates and unload + for cu_data in itervalues(self.custom_updates): + cu_data.unload() + + # Loop through current sources and unload + for src_data in itervalues(self.current_sources): + src_data.unload() + + # Loop through synapse populations and unload + for pop_data in itervalues(self.synapse_populations): + pop_data.unload() + + # Loop through neuron populations and unload + for pop_data in itervalues(self.neuron_populations): + pop_data.unload() + + # Close shared library model + self._slm.close() + + # Clear loaded flag + self._loaded = False + def reinitialise(self): """reinitialise model to its original state without re-loading""" if not self._loaded: diff --git a/userproject/include/sharedLibraryModel.h b/userproject/include/sharedLibraryModel.h index 28f37f8734..7a7866947d 100644 --- a/userproject/include/sharedLibraryModel.h +++ b/userproject/include/sharedLibraryModel.h @@ -50,15 +50,8 @@ class SharedLibraryModel virtual ~SharedLibraryModel() { - // Close model library if loaded successfully - if(m_Library) { - freeMem(); -#ifdef _WIN32 - FreeLibrary(m_Library); -#else - dlclose(m_Library); -#endif - } + // Close model library + close(); } //---------------------------------------------------------------------------- @@ -112,6 +105,40 @@ class SharedLibraryModel } } + void close() + { + if(m_Library) { + freeMem(); +#ifdef _WIN32 + FreeLibrary(m_Library); +#else + dlclose(m_Library); +#endif + m_Library = nullptr; + } + + // Null all pointers + m_AllocateMem = nullptr; + m_AllocateRecordingBuffers = nullptr; + m_FreeMem = nullptr; + m_GetFreeDeviceMemBytes = nullptr; + m_Initialize = nullptr; + m_InitializeSparse = nullptr; + m_StepTime = nullptr; + m_PullRecordingBuffersFromDevice = nullptr; + m_NCCLGenerateUniqueID = nullptr; + m_NCCLGetUniqueID = nullptr; + m_NCCLInitCommunicator = nullptr; + m_NCCLUniqueIDBytes = nullptr; + m_T = nullptr; + m_Timestep = nullptr; + + // Empty all dictionaries + m_PopulationVars.clear(); + m_PopulationEPGs.clear(); + m_CustomUpdates.clear(); + } + void allocateExtraGlobalParam(const std::string &popName, const std::string &egpName, unsigned int count) { // Get EGP functions and check allocate exists