Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ability to unload PyGeNN models #581

Merged
merged 1 commit into from
May 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion pygenn/genn_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,25 @@ def _load_var_init_egps(self, var_dict=None):
for var_name, var_data in iteritems(var_dict):
self._load_egp(var_data.extra_global_params, var_name)

def _unload_vars(self, var_dict=None):
# If no variable dictionary is specified, use standard one
if var_dict is None:
var_dict = self.vars

# Loop through variables and clear views
for v in itervalues(var_dict):
v.view = None
for e in itervalues(v.extra_global_params):
e.view = None

def _unload_egps(self, egp_dict=None):
# If no EGP dictionary is specified, use standard one
if egp_dict is None:
egp_dict = self.extra_global_params

# Loop through extra global params and clear views
for e in itervalues(egp_dict):
e.view = None

class NeuronGroup(Group):

Expand All @@ -339,12 +358,14 @@ def __init__(self, name, model):
self.spike_count = None
self.spike_events = None
self.spike_event_count = None
self.spike_que_ptr = [0]
self.spike_que_ptr = None
self._max_delay_steps = 0
self.spike_times = None
self.prev_spike_times = None
self.spike_event_times = None
self.prev_spike_event_times = None
self._spike_recording_data = None
self._spike_event_recording_data = None

@property
def current_spikes(self):
Expand Down Expand Up @@ -591,6 +612,22 @@ def load(self, num_recording_timesteps):
# Load neuron extra global params
self._load_egp()

def unload(self):
self.spikes = None
self.spike_count = None
self.spike_events = None
self.spike_event_count = None
self.spike_que_ptr = None
self.spike_times = None
self.prev_spike_times = None
self.spike_event_times = None
self.prev_spike_event_times = None
self._spike_recording_data = None
self._spike_event_recording_data = None

self._unload_vars()
self._unload_egps()

def load_init_egps(self):
# Load any egps used for variable initialisation
self._load_var_init_egps()
Expand Down Expand Up @@ -1293,6 +1330,19 @@ def load_init_egps(self):
self._load_var_init_egps(self.pre_vars)
self._load_var_init_egps(self.post_vars)

def unload(self):
self._ind = None
self._row_lengths = None
self.in_syn = None

self._unload_vars()
self._unload_vars(self.pre_vars)
self._unload_vars(self.post_vars)
self._unload_vars(self.psm_vars)
self._unload_egps()
self._unload_egps(self.psm_extra_global_params)
self._unload_egps(self.connectivity_extra_global_params)

def reinitialise(self):
"""Reinitialise synapse group"""
# If population has individual synapse variables
Expand Down Expand Up @@ -1416,6 +1466,10 @@ def load_init_egps(self):
# Load any egps used for variable initialisation
self._load_var_init_egps()

def unload(self):
self._unload_vars()
self._unload_egps()

def reinitialise(self):
"""Reinitialise current source"""
# Reinitialise current source state variables
Expand Down Expand Up @@ -1547,6 +1601,10 @@ def load_init_egps(self):
# Load any egps used for variable initialisation
self._load_var_init_egps()

def unload(self):
self._unload_vars()
self._unload_egps()

def reinitialise(self):
"""Reinitialise custom update"""
# If this is a custom weight update
Expand Down
25 changes: 24 additions & 1 deletion pygenn/genn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ def load(self, path_to_model="./", num_recording_timesteps=None):
# Allocate recording buffers
self._slm.allocate_recording_buffers(num_recording_timesteps)

# Loop through synapse populations and load any
# Loop through neuron populations and load any
# extra global parameters required for initialization
for pop_data in itervalues(self.neuron_populations):
pop_data.load_init_egps()
Expand Down Expand Up @@ -690,6 +690,29 @@ def load(self, path_to_model="./", num_recording_timesteps=None):
self._loaded = True
self._built = True

def unload(self):
# Loop through custom updates and unload
for cu_data in itervalues(self.custom_updates):
cu_data.unload()

# Loop through current sources and unload
for src_data in itervalues(self.current_sources):
src_data.unload()

# Loop through synapse populations and unload
for pop_data in itervalues(self.synapse_populations):
pop_data.unload()

# Loop through neuron populations and unload
for pop_data in itervalues(self.neuron_populations):
pop_data.unload()

# Close shared library model
self._slm.close()

# Clear loaded flag
self._loaded = False

def reinitialise(self):
"""reinitialise model to its original state without re-loading"""
if not self._loaded:
Expand Down
45 changes: 36 additions & 9 deletions userproject/include/sharedLibraryModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,8 @@ class SharedLibraryModel

virtual ~SharedLibraryModel()
{
// Close model library if loaded successfully
if(m_Library) {
freeMem();
#ifdef _WIN32
FreeLibrary(m_Library);
#else
dlclose(m_Library);
#endif
}
// Close model library
close();
}

//----------------------------------------------------------------------------
Expand Down Expand Up @@ -112,6 +105,40 @@ class SharedLibraryModel
}
}

void close()
{
if(m_Library) {
freeMem();
#ifdef _WIN32
FreeLibrary(m_Library);
#else
dlclose(m_Library);
#endif
m_Library = nullptr;
}

// Null all pointers
m_AllocateMem = nullptr;
m_AllocateRecordingBuffers = nullptr;
m_FreeMem = nullptr;
m_GetFreeDeviceMemBytes = nullptr;
m_Initialize = nullptr;
m_InitializeSparse = nullptr;
m_StepTime = nullptr;
m_PullRecordingBuffersFromDevice = nullptr;
m_NCCLGenerateUniqueID = nullptr;
m_NCCLGetUniqueID = nullptr;
m_NCCLInitCommunicator = nullptr;
m_NCCLUniqueIDBytes = nullptr;
m_T = nullptr;
m_Timestep = nullptr;

// Empty all dictionaries
m_PopulationVars.clear();
m_PopulationEPGs.clear();
m_CustomUpdates.clear();
}

void allocateExtraGlobalParam(const std::string &popName, const std::string &egpName, unsigned int count)
{
// Get EGP functions and check allocate exists
Expand Down