diff --git a/brainpy/_add_deprecations.py b/brainpy/_add_deprecations.py index 17edcff31..d04c3aa2e 100644 --- a/brainpy/_add_deprecations.py +++ b/brainpy/_add_deprecations.py @@ -88,6 +88,16 @@ # neurons 'NeuGroup': ('brainpy.dyn.NeuGroup', 'brainpy.dyn.NeuDyn', NeuDyn), + # projections + 'ProjAlignPostMg1': ('brainpy.dyn.ProjAlignPostMg1', 'brainpy.dyn.HalfProjAlignPostMg', dyn.HalfProjAlignPostMg), + 'ProjAlignPostMg2': ('brainpy.dyn.ProjAlignPostMg2', 'brainpy.dyn.FullProjAlignPostMg', dyn.FullProjAlignPostMg), + 'ProjAlignPost1': ('brainpy.dyn.ProjAlignPost1', 'brainpy.dyn.HalfProjAlignPost', dyn.HalfProjAlignPost), + 'ProjAlignPost2': ('brainpy.dyn.ProjAlignPost2', 'brainpy.dyn.FullProjAlignPost', dyn.FullProjAlignPost), + 'ProjAlignPreMg1': ('brainpy.dyn.ProjAlignPreMg1', 'brainpy.dyn.FullProjAlignPreSDMg', dyn.FullProjAlignPreSDMg), + 'ProjAlignPreMg2': ('brainpy.dyn.ProjAlignPreMg2', 'brainpy.dyn.FullProjAlignPreDSMg', dyn.FullProjAlignPreDSMg), + 'ProjAlignPre1': ('brainpy.dyn.ProjAlignPre1', 'brainpy.dyn.FullProjAlignPreSD', dyn.FullProjAlignPreSD), + 'ProjAlignPre2': ('brainpy.dyn.ProjAlignPre2', 'brainpy.dyn.FullProjAlignPreDS', dyn.FullProjAlignPreDS), + # synapses 'TwoEndConn': ('brainpy.dyn.TwoEndConn', 'brainpy.synapses.TwoEndConn', synapses.TwoEndConn), 'SynSTP': ('brainpy.dyn.SynSTP', 'brainpy.synapses.SynSTP', synapses.SynSTP), diff --git a/brainpy/_src/dyn/neurons/hh.py b/brainpy/_src/dyn/neurons/hh.py index 97e612097..f9145a94b 100644 --- a/brainpy/_src/dyn/neurons/hh.py +++ b/brainpy/_src/dyn/neurons/hh.py @@ -61,7 +61,7 @@ class CondNeuGroupLTC(HHTypedNeuron, Container, TreeNode): where :math:`\alpha_{x}` and :math:`\beta_{x}` are rate constants. .. versionadded:: 2.1.9 - Model the conductance-based neuron model. + Modeling the conductance-based neuron model. Parameters ---------- @@ -117,7 +117,7 @@ def __init__( def derivative(self, V, t, I): # synapses - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) # channels for ch in self.nodes(level=1, include_self=False).subset(IonChaDyn).unique().values(): I = I + ch.current(V) @@ -140,7 +140,7 @@ def update(self, x=None): x = x * (1e-3 / self.A) # integral - V = self.integral(self.V.value, share['t'], x, share['dt']) + V = self.integral(self.V.value, share['t'], x, share['dt']) + self.sum_delta_inputs() # check whether the children channels have the correct parents. channels = self.nodes(level=1, include_self=False).subset(IonChaDyn).unique() @@ -176,7 +176,7 @@ def derivative(self, V, t, I): def update(self, x=None): # inputs x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -384,7 +384,7 @@ def reset_state(self, batch_size=None, **kwargs): self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) def dV(self, V, t, m, h, n, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) I_Na = (self.gNa * m * m * m * h) * (V - self.ENa) n2 = n * n I_K = (self.gK * n2 * n2) * (V - self.EK) @@ -402,6 +402,7 @@ def update(self, x=None): x = 0. if x is None else x V, m, h, n = self.integral(self.V.value, self.m.value, self.h.value, self.n.value, t, x, dt) + V += self.sum_delta_inputs() self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) self.V.value = V self.m.value = m @@ -532,7 +533,7 @@ def derivative(self): def update(self, x=None): x = 0. 
if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -662,7 +663,7 @@ def reset_state(self, batch_or_mode=None, **kwargs): self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_or_mode) def dV(self, V, t, W, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2)) I_Ca = self.g_Ca * M_inf * (V - self.V_Ca) I_K = self.g_K * W * (V - self.V_K) @@ -685,6 +686,7 @@ def update(self, x=None): dt = share.load('dt') x = 0. if x is None else x V, W = self.integral(self.V, self.W, t, x, dt) + V += self.sum_delta_inputs() spike = bm.logical_and(self.V < self.V_th, V >= self.V_th) self.V.value = V self.W.value = W @@ -761,7 +763,7 @@ def dV(self, V, t, W, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -951,7 +953,7 @@ def dn(self, n, t, V): return self.phi * dndt def dV(self, V, t, h, n, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) INa = self.gNa * self.m_inf(V) ** 3 * h * (V - self.ENa) IK = self.gK * n ** 4 * (V - self.EK) IL = self.gL * (V - self.EL) @@ -968,6 +970,7 @@ def update(self, x=None): x = 0. if x is None else x V, h, n = self.integral(self.V, self.h, self.n, t, x, dt) + V += self.sum_delta_inputs() self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) self.V.value = V self.h.value = h @@ -1091,5 +1094,5 @@ def dV(self, V, t, h, n, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) diff --git a/brainpy/_src/dyn/neurons/lif.py b/brainpy/_src/dyn/neurons/lif.py index 988c915ac..11934d9dc 100644 --- a/brainpy/_src/dyn/neurons/lif.py +++ b/brainpy/_src/dyn/neurons/lif.py @@ -5,12 +5,12 @@ import brainpy.math as bm from brainpy._src.context import share +from brainpy._src.dyn._docs import ref_doc, lif_doc, pneu_doc, dpneu_doc, ltc_doc, if_doc +from brainpy._src.dyn.neurons.base import GradNeuDyn from brainpy._src.initialize import ZeroInit, OneInit from brainpy._src.integrators import odeint, JointEq from brainpy.check import is_initializer from brainpy.types import Shape, ArrayType, Sharding -from brainpy._src.dyn._docs import ref_doc, lif_doc, pneu_doc, dpneu_doc, ltc_doc, if_doc -from brainpy._src.dyn.neurons.base import GradNeuDyn __all__ = [ 'IF', @@ -119,7 +119,7 @@ def __init__( self.reset_state(self.mode) def derivative(self, V, t, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) return (-V + self.V_rest + self.R * I) / self.tau def reset_state(self, batch_size=None, **kwargs): @@ -132,7 +132,7 @@ def update(self, x=None): x = 0. if x is None else x # integrate membrane potential - self.V.value = self.integral(self.V.value, t, x, dt) + self.V.value = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() return self.V.value @@ -146,7 +146,7 @@ def derivative(self, V, t, I): def update(self, x=None): x = 0. 
if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -252,7 +252,7 @@ def __init__( self.reset_state(self.mode) def derivative(self, V, t, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) return (-V + self.V_rest + self.R * I) / self.tau def reset_state(self, batch_size=None, **kwargs): @@ -265,7 +265,7 @@ def update(self, x=None): x = 0. if x is None else x # integrate membrane potential - V = self.integral(self.V.value, t, x, dt) + V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() # spike, spiking time, and membrane potential reset if isinstance(self.mode, bm.TrainingMode): @@ -337,7 +337,7 @@ def derivative(self, V, t, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -464,7 +464,7 @@ def update(self, x=None): x = 0. if x is None else x # integrate membrane potential - V = self.integral(self.V.value, t, x, dt) + V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() # refractory refractory = (t - self.t_last_spike) <= self.tau_ref @@ -552,7 +552,7 @@ def derivative(self, V, t, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -723,7 +723,7 @@ def __init__( self.reset_state(self.mode) def derivative(self, V, t, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) dvdt = (- (V - self.V_rest) + exp_v + self.R * I) / self.tau return dvdt @@ -738,7 +738,7 @@ def update(self, x=None): x = 0. if x is None else x # integrate membrane potential - V = self.integral(self.V.value, t, x, dt) + V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() # spike, spiking time, and membrane potential reset if isinstance(self.mode, bm.TrainingMode): @@ -880,7 +880,7 @@ def derivative(self, V, t, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -994,6 +994,7 @@ class ExpIFRefLTC(ExpIFLTC): %s """ + def __init__( self, size: Shape, @@ -1076,7 +1077,7 @@ def update(self, x=None): x = 0. if x is None else x # integrate membrane potential - V = self.integral(self.V.value, t, x, dt) + V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() # refractory refractory = (t - self.t_last_spike) <= self.tau_ref @@ -1221,6 +1222,7 @@ class ExpIFRef(ExpIFRefLTC): %s %s """ + def derivative(self, V, t, I): exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) dvdt = (- (V - self.V_rest) + exp_v + self.R * I) / self.tau @@ -1228,7 +1230,7 @@ def derivative(self, V, t, I): def update(self, x=None): x = 0. 
if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -1400,7 +1402,7 @@ def __init__( self.reset_state(self.mode) def dV(self, V, t, w, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) dVdt = (- V + self.V_rest + exp - self.R * w + self.R * I) / self.tau return dVdt @@ -1425,6 +1427,7 @@ def update(self, x=None): # integrate membrane potential V, w = self.integral(self.V.value, self.w.value, t, x, dt) + V += self.sum_delta_inputs() # spike, spiking time, and membrane potential reset if isinstance(self.mode, bm.TrainingMode): @@ -1559,7 +1562,7 @@ def dV(self, V, t, w, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -1757,6 +1760,7 @@ def update(self, x=None): # integrate membrane potential V, w = self.integral(self.V.value, self.w.value, t, x, dt) + V += self.sum_delta_inputs() # refractory refractory = (t - self.t_last_spike) <= self.tau_ref @@ -1901,7 +1905,7 @@ def dV(self, V, t, w, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -2040,7 +2044,7 @@ def __init__( self.reset_state(self.mode) def derivative(self, V, t, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I) / self.tau return dVdt @@ -2054,7 +2058,7 @@ def update(self, x=None): x = 0. if x is None else x # integrate membrane potential - V = self.integral(self.V.value, t, x, dt) + V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() # spike, spiking time, and membrane potential reset if isinstance(self.mode, bm.TrainingMode): @@ -2166,7 +2170,7 @@ def derivative(self, V, t, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -2330,7 +2334,7 @@ def update(self, x=None): x = 0. if x is None else x # integrate membrane potential - V = self.integral(self.V.value, t, x, dt) + V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() # refractory refractory = (t - self.t_last_spike) <= self.tau_ref @@ -2444,14 +2448,13 @@ class QuaIFRef(QuaIFRefLTC): %s """ - def derivative(self, V, t, I): dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I) / self.tau return dVdt def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -2609,7 +2612,7 @@ def __init__( self.reset_state(self.mode) def dV(self, V, t, w, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I) / self.tau return dVdt @@ -2633,6 +2636,7 @@ def update(self, x=None): # integrate membrane potential V, w = self.integral(self.V.value, self.w.value, t, x, dt) + V += self.sum_delta_inputs() # spike, spiking time, and membrane potential reset if isinstance(self.mode, bm.TrainingMode): @@ -2756,7 +2760,7 @@ def dV(self, V, t, w, I): def update(self, x=None): x = 0. 
if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -2939,6 +2943,7 @@ def update(self, x=None): # integrate membrane potential V, w = self.integral(self.V.value, self.w.value, t, x, dt) + V += self.sum_delta_inputs() # refractory refractory = (t - self.t_last_spike) <= self.tau_ref @@ -3072,7 +3077,7 @@ def dV(self, V, t, w, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -3279,7 +3284,7 @@ def dVth(self, V_th, t, V): return self.a * (V - self.V_rest) - self.b * (V_th - self.V_th_inf) def dV(self, V, t, I1, I2, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) return (- (V - self.V_rest) + self.R * (I + I1 + I2)) / self.tau @property @@ -3300,6 +3305,7 @@ def update(self, x=None): # integrate membrane potential I1, I2, V_th, V = self.integral(self.I1.value, self.I2.value, self.V_th.value, self.V.value, t, x, dt) + V += self.sum_delta_inputs() # spike, spiking time, and membrane potential reset if isinstance(self.mode, bm.TrainingMode): @@ -3452,7 +3458,7 @@ def dV(self, V, t, I1, I2, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -3573,7 +3579,6 @@ class GifRefLTC(GifLTC): %s """ - def __init__( self, size: Shape, @@ -3680,6 +3685,7 @@ def update(self, x=None): # integrate membrane potential I1, I2, V_th, V = self.integral(self.I1.value, self.I2.value, self.V_th.value, self.V.value, t, x, dt) + V += self.sum_delta_inputs() # refractory refractory = (t - self.t_last_spike) <= self.tau_ref @@ -3840,13 +3846,12 @@ class GifRef(GifRefLTC): %s """ - def dV(self, V, t, I1, I2, I): return (- (V - self.V_rest) + self.R * (I + I1 + I2)) / self.tau def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -4012,7 +4017,7 @@ def __init__( self.reset_state(self.mode) def dV(self, V, t, u, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) dVdt = self.p1 * V * V + self.p2 * V + self.p3 - u + I return dVdt @@ -4040,6 +4045,7 @@ def update(self, x=None): # integrate membrane potential V, u = self.integral(self.V.value, self.u.value, t, x, dt) + V += self.sum_delta_inputs() # spike, spiking time, and membrane potential reset if isinstance(self.mode, bm.TrainingMode): @@ -4161,7 +4167,7 @@ def dV(self, V, t, u, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -4351,6 +4357,7 @@ def update(self, x=None): # integrate membrane potential V, u = self.integral(self.V.value, self.u.value, t, x, dt) + V += self.sum_delta_inputs() # refractory refractory = (t - self.t_last_spike) <= self.tau_ref @@ -4485,11 +4492,11 @@ def dV(self, V, t, u, I): def update(self, x=None): x = 0. 
if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) -Izhikevich.__doc__ = Izhikevich.__doc__ %(pneu_doc, dpneu_doc) -IzhikevichRefLTC.__doc__ = IzhikevichRefLTC.__doc__ %(pneu_doc, dpneu_doc, ref_doc) -IzhikevichRef.__doc__ = IzhikevichRef.__doc__ %(pneu_doc, dpneu_doc, ref_doc) -IzhikevichLTC.__doc__ = IzhikevichLTC.__doc__ %() +Izhikevich.__doc__ = Izhikevich.__doc__ % (pneu_doc, dpneu_doc) +IzhikevichRefLTC.__doc__ = IzhikevichRefLTC.__doc__ % (pneu_doc, dpneu_doc, ref_doc) +IzhikevichRef.__doc__ = IzhikevichRef.__doc__ % (pneu_doc, dpneu_doc, ref_doc) +IzhikevichLTC.__doc__ = IzhikevichLTC.__doc__ % () diff --git a/brainpy/_src/dyn/others/common.py b/brainpy/_src/dyn/others/common.py index 7cf4f98b8..812375787 100644 --- a/brainpy/_src/dyn/others/common.py +++ b/brainpy/_src/dyn/others/common.py @@ -77,7 +77,7 @@ def update(self, inp=None): dt = share.load('dt') self.x.value = self.integral(self.x.value, t, dt) if inp is None: inp = 0. - inp = self.sum_inputs(self.x.value, init=inp) + inp = self.sum_current_inputs(self.x.value, init=inp) self.x += inp return self.x.value diff --git a/brainpy/_src/dyn/outs/outputs.py b/brainpy/_src/dyn/outs/outputs.py index 5dc54a232..8171367d7 100644 --- a/brainpy/_src/dyn/outs/outputs.py +++ b/brainpy/_src/dyn/outs/outputs.py @@ -82,7 +82,7 @@ def __init__( super().__init__(name=name, scaling=scaling) def update(self, conductance, potential=None): - return self.std_scaling(conductance) + return conductance class MgBlock(SynOut): @@ -138,5 +138,5 @@ def __init__( self.beta = init.parameter(beta, np.shape(beta), sharding=sharding) def update(self, conductance, potential): - return conductance *\ - (self.E - potential) / (1 + self.cc_Mg / self.beta * bm.exp(self.alpha * (self.V_offset - potential))) + norm = (1 + self.cc_Mg / self.beta * bm.exp(self.alpha * (self.V_offset - potential))) + return conductance * (self.E - potential) / norm diff --git a/brainpy/_src/dyn/projections/__init__.py b/brainpy/_src/dyn/projections/__init__.py index 8a7040824..e69de29bb 100644 --- a/brainpy/_src/dyn/projections/__init__.py +++ b/brainpy/_src/dyn/projections/__init__.py @@ -1,5 +0,0 @@ - -from .aligns import * -from .conn import * -from .others import * -from .inputs import * diff --git a/brainpy/_src/dyn/projections/align_post.py b/brainpy/_src/dyn/projections/align_post.py new file mode 100644 index 000000000..b5679dc7d --- /dev/null +++ b/brainpy/_src/dyn/projections/align_post.py @@ -0,0 +1,490 @@ +from typing import Optional, Callable, Union + +from brainpy import math as bm, check +from brainpy._src.delay import (delay_identifier, + register_delay_by_return) +from brainpy._src.dynsys import DynamicalSystem, Projection +from brainpy._src.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData, AlignPost) + +__all__ = [ + 'HalfProjAlignPostMg', 'FullProjAlignPostMg', + 'HalfProjAlignPost', 'FullProjAlignPost', + +] + + +def get_post_repr(out_label, syn, out): + return f'{out_label} // {syn.identifier} // {out.identifier}' + + +def align_post_add_bef_update(out_label, syn_desc, out_desc, post, proj_name): + # synapse and output initialization + _post_repr = get_post_repr(out_label, syn_desc, out_desc) + if not post.has_bef_update(_post_repr): + syn_cls = syn_desc() + out_cls = out_desc() + + # synapse and output initialization + post.add_inp_fun(proj_name, out_cls, label=out_label) + post.add_bef_update(_post_repr, _AlignPost(syn_cls, out_cls)) + syn = 
post.get_bef_update(_post_repr).syn + out = post.get_bef_update(_post_repr).out + return syn, out + + +class _AlignPost(DynamicalSystem): + def __init__(self, + syn: Callable, + out: JointType[DynamicalSystem, BindCondData]): + super().__init__() + self.syn = syn + self.out = out + + def update(self, *args, **kwargs): + self.out.bind_cond(self.syn(*args, **kwargs)) + + def reset_state(self, *args, **kwargs): + pass + + +class HalfProjAlignPostMg(Projection): + r"""Defining the half part of synaptic projection with the align-post reduction and the automatic synapse merging. + + The ``half-part`` means that the model only needs to provide half information needed for a projection, + including ``comm`` -> ``syn`` -> ``out`` -> ``post``. Therefore, the model's ``update`` function needs + the manual providing of the spiking input. + + The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group. + + The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same + parameters (such like time constants) will also share the same synaptic variables. + + All align-post projection models prefer to use the event-driven computation mode. This means that the + ``comm`` model should be the event-driven model. + + **Code Examples** + + To define an E/I balanced network model. + + .. code-block:: python + + import brainpy as bp + import brainpy.math as bm + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.E = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), + syn=bp.dyn.Expon.desc(size=4000, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.N) + self.I = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), + syn=bp.dyn.Expon.desc(size=4000, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(spk[:3200]) + self.I(spk[3200:]) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + Args: + comm: The synaptic communication. + syn: The synaptic dynamics. + out: The synaptic output. + post: The post-synaptic neuron group. + out_label: str. The prefix of the output function. + name: str. The projection name. + mode: Mode. The computing mode. 
+ """ + + def __init__( + self, + comm: DynamicalSystem, + syn: ParamDescriber[JointType[DynamicalSystem, AlignPost]], + out: ParamDescriber[JointType[DynamicalSystem, BindCondData]], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(comm, DynamicalSystem) + check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, AlignPost]]) + check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # synapse and output initialization + syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) + + # references + self.refs = dict(post=post) # invisible to ``self.nodes()`` + self.refs['syn'] = syn + self.refs['out'] = out + self.refs['comm'] = comm # unify the access + + def update(self, x): + current = self.comm(x) + self.refs['syn'].add_current(current) # synapse post current + return current + + +class FullProjAlignPostMg(Projection): + """Full-chain synaptic projection with the align-post reduction and the automatic synapse merging. + + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``delay`` -> ``comm`` -> ``syn`` -> ``out`` -> ``post``. + + The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group. + + The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same + parameters (such like time constants) will also share the same synaptic variables. + + All align-post projection models prefer to use the event-driven computation mode. This means that the + ``comm`` model should be the event-driven model. + + Moreover, it's worth noting that ``FullProjAlignPostMg`` has a different updating order with all align-pre + projection models. The updating order of align-post projections is ``spikes`` -> ``comm`` -> ``syn`` -> ``out``. + While, the updating order of all align-pre projection models is usually ``spikes`` -> ``syn`` -> ``comm`` -> ``out``. + + **Code Examples** + + To define an E/I balanced network model. + + .. 
code-block:: python + + import brainpy as bp + import brainpy.math as bm + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + ne, ni = 3200, 800 + self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPostMg(pre=self.E, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPostMg(pre=self.E, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + syn=bp.dyn.Expon.desc(size=ni, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPostMg(pre=self.I, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + syn=bp.dyn.Expon.desc(size=ne, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPostMg(pre=self.I, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + Args: + pre: The pre-synaptic neuron group. + delay: The synaptic delay. + comm: The synaptic communication. + syn: The synaptic dynamics. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. 
+ """ + + def __init__( + self, + pre: JointType[DynamicalSystem, SupportAutoDelay], + delay: Union[None, int, float], + comm: DynamicalSystem, + syn: ParamDescriber[JointType[DynamicalSystem, AlignPost]], + out: ParamDescriber[JointType[DynamicalSystem, BindCondData]], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) + check.is_instance(comm, DynamicalSystem) + check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, AlignPost]]) + check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # delay initialization + delay_cls = register_delay_by_return(pre) + delay_cls.register_entry(self.name, delay) + + # synapse and output initialization + syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) + + # references + self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()`` + self.refs['syn'] = syn # invisible to ``self.node()`` + self.refs['out'] = out # invisible to ``self.node()`` + # unify the access + self.refs['comm'] = comm + self.refs['delay'] = pre.get_aft_update(delay_identifier) + + def update(self): + x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name) + current = self.comm(x) + self.refs['syn'].add_current(current) # synapse post current + return current + + +class HalfProjAlignPost(Projection): + """Defining the half-part of synaptic projection with the align-post reduction. + + The ``half-part`` means that the model only needs to provide half information needed for a projection, + including ``comm`` -> ``syn`` -> ``out`` -> ``post``. Therefore, the model's ``update`` function needs + the manual providing of the spiking input. + + The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group. + + All align-post projection models prefer to use the event-driven computation mode. This means that the + ``comm`` model should be the event-driven model. + + To simulate an E/I balanced network: + + .. code-block:: + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.E = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), + syn=bp.dyn.Expon(size=4000, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), + syn=bp.dyn.Expon(size=4000, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(spk[:3200]) + self.I(spk[3200:]) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + comm: The synaptic communication. + syn: The synaptic dynamics. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. 
+ """ + + def __init__( + self, + comm: DynamicalSystem, + syn: JointType[DynamicalSystem, AlignPost], + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(comm, DynamicalSystem) + check.is_instance(syn, JointType[DynamicalSystem, AlignPost]) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + self.syn = syn + self.out = out + + # synapse and output initialization + post.add_inp_fun(self.name, out, label=out_label) + + # reference + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['post'] = post + self.refs['syn'] = syn + self.refs['out'] = out + # unify the access + self.refs['comm'] = comm + + def update(self, x): + current = self.comm(x) + g = self.syn(self.comm(x)) + self.refs['out'].bind_cond(g) # synapse post current + return current + + +class FullProjAlignPost(Projection): + """Full-chain synaptic projection with the align-post reduction. + + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``delay`` -> ``comm`` -> ``syn`` -> ``out`` -> ``post``. + + The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group. + + All align-post projection models prefer to use the event-driven computation mode. This means that the + ``comm`` model should be the event-driven model. + + Moreover, it's worth noting that ``FullProjAlignPost`` has a different updating order with all align-pre + projection models. The updating order of align-post projections is ``spikes`` -> ``comm`` -> ``syn`` -> ``out``. + While, the updating order of all align-pre projection models is usually ``spikes`` -> ``syn`` -> ``comm`` -> ``out``. + + To simulate and define an E/I balanced network model: + + .. 
code-block:: python + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + ne, ni = 3200, 800 + self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPost(pre=self.E, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + syn=bp.dyn.Expon(size=ne, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPost(pre=self.E, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + syn=bp.dyn.Expon(size=ni, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPost(pre=self.I, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + syn=bp.dyn.Expon(size=ne, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPost(pre=self.I, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + syn=bp.dyn.Expon(size=ni, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + pre: The pre-synaptic neuron group. + delay: The synaptic delay. + comm: The synaptic communication. + syn: The synaptic dynamics. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. 
+ """ + + def __init__( + self, + pre: JointType[DynamicalSystem, SupportAutoDelay], + delay: Union[None, int, float], + comm: DynamicalSystem, + syn: JointType[DynamicalSystem, AlignPost], + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) + check.is_instance(comm, DynamicalSystem) + check.is_instance(syn, JointType[DynamicalSystem, AlignPost]) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + self.syn = syn + + # delay initialization + delay_cls = register_delay_by_return(pre) + delay_cls.register_entry(self.name, delay) + + # synapse and output initialization + post.add_inp_fun(self.name, out, label=out_label) + + # references + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['pre'] = pre + self.refs['post'] = post + self.refs['out'] = out + # unify the access + self.refs['delay'] = delay_cls + self.refs['comm'] = comm + self.refs['syn'] = syn + + def update(self): + x = self.refs['delay'].at(self.name) + g = self.syn(self.comm(x)) + self.refs['out'].bind_cond(g) # synapse post current + return g diff --git a/brainpy/_src/dyn/projections/align_pre.py b/brainpy/_src/dyn/projections/align_pre.py new file mode 100644 index 000000000..356de0a6d --- /dev/null +++ b/brainpy/_src/dyn/projections/align_pre.py @@ -0,0 +1,583 @@ +from typing import Optional, Union + +from brainpy import math as bm, check +from brainpy._src.delay import (Delay, DelayAccess, init_delay_by_return, register_delay_by_return) +from brainpy._src.dynsys import DynamicalSystem, Projection +from brainpy._src.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData) +from .utils import _get_return + +__all__ = [ + 'FullProjAlignPreSDMg', 'FullProjAlignPreDSMg', + 'FullProjAlignPreSD', 'FullProjAlignPreDS', +] + + +def align_pre2_add_bef_update(syn_desc, delay, delay_cls, proj_name=None): + _syn_id = f'Delay({str(delay)}) // {syn_desc.identifier}' + if not delay_cls.has_bef_update(_syn_id): + # delay + delay_access = DelayAccess(delay_cls, delay, delay_entry=proj_name) + # synapse + syn_cls = syn_desc() + # add to "after_updates" + delay_cls.add_bef_update(_syn_id, _AlignPreMg(delay_access, syn_cls)) + syn = delay_cls.get_bef_update(_syn_id).syn + return syn + + +class _AlignPreMg(DynamicalSystem): + def __init__(self, access, syn): + super().__init__() + self.access = access + self.syn = syn + + def update(self, *args, **kwargs): + return self.syn(self.access()) + + def reset_state(self, *args, **kwargs): + pass + + +def align_pre1_add_bef_update(syn_desc, pre): + _syn_id = f'{syn_desc.identifier} // Delay' + if not pre.has_aft_update(_syn_id): + # "syn_cls" needs an instance of "ProjAutoDelay" + syn_cls: SupportAutoDelay = syn_desc() + delay_cls = init_delay_by_return(syn_cls.return_info()) + # add to "after_updates" + pre.add_aft_update(_syn_id, _AlignPre(syn_cls, delay_cls)) + delay_cls: Delay = pre.get_aft_update(_syn_id).delay + syn = pre.get_aft_update(_syn_id).syn + return delay_cls, syn + + +class _AlignPre(DynamicalSystem): + def __init__(self, syn, delay=None): + super().__init__() + self.syn = syn + self.delay = delay + + def update(self, x): + if self.delay is None: + return x >> self.syn + else: + return x >> self.syn >> self.delay + + def 
reset_state(self, *args, **kwargs): + pass + + +class FullProjAlignPreSDMg(Projection): + """Full-chain synaptic projection with the align-pre reduction and synapse+delay updating and merging. + + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``. + + The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. + + The ``synapse+delay updating`` means that the projection first computes the synapse states, then delivers the + synapse states to the delay model, and finally computes the synaptic current. + + The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same + parameters (such like time constants) will also share the same synaptic variables. + + Neither ``FullProjAlignPreSDMg`` nor ``FullProjAlignPreDSMg``facilitates the event-driven computation. + This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather + than the spiking. To facilitate the event-driven computation, please use align post projections. + + To simulate an E/I balanced network model: + + .. code-block:: python + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + ne, ni = 3200, 800 + self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + pre: The pre-synaptic neuron group. + syn: The synaptic dynamics. + delay: The synaptic delay. + comm: The synaptic communication. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. 
+ """ + + def __init__( + self, + pre: DynamicalSystem, + syn: ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]], + delay: Union[None, int, float], + comm: DynamicalSystem, + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, DynamicalSystem) + check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]]) + check.is_instance(comm, DynamicalSystem) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # synapse and delay initialization + delay_cls, syn_cls = align_pre1_add_bef_update(syn, pre) + delay_cls.register_entry(self.name, delay) + + # output initialization + post.add_inp_fun(self.name, out, label=out_label) + + # references + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['pre'] = pre + self.refs['post'] = post + self.refs['out'] = out + self.refs['delay'] = delay_cls + self.refs['syn'] = syn_cls + # unify the access + self.refs['comm'] = comm + + def update(self, x=None): + if x is None: + x = self.refs['delay'].at(self.name) + current = self.comm(x) + self.refs['out'].bind_cond(current) + return current + + +class FullProjAlignPreDSMg(Projection): + """Full-chain synaptic projection with the align-pre reduction and delay+synapse updating and merging. + + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``delay`` -> ``syn`` -> ``comm`` -> ``out`` -> ``post``. + Note here, compared to ``FullProjAlignPreSDMg``, the ``delay`` and ``syn`` are exchanged. + + The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. + + The ``delay+synapse updating`` means that the projection first delivers the pre neuron output (usually the + spiking) to the delay model, then computes the synapse states, and finally computes the synaptic current. + + The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same + parameters (such like time constants) will also share the same synaptic variables. + + Neither ``FullProjAlignPreDSMg`` nor ``FullProjAlignPreSDMg`` facilitates the event-driven computation. + This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather + than the spiking. To facilitate the event-driven computation, please use align post projections. + + + To simulate an E/I balanced network model: + + .. 
code-block:: python + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + ne, ni = 3200, 800 + self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + pre: The pre-synaptic neuron group. + delay: The synaptic delay. + syn: The synaptic dynamics. + comm: The synaptic communication. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. + """ + + def __init__( + self, + pre: JointType[DynamicalSystem, SupportAutoDelay], + delay: Union[None, int, float], + syn: ParamDescriber[DynamicalSystem], + comm: DynamicalSystem, + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) + check.is_instance(syn, ParamDescriber[DynamicalSystem]) + check.is_instance(comm, DynamicalSystem) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # delay initialization + delay_cls = register_delay_by_return(pre) + + # synapse initialization + syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name) + + # output initialization + post.add_inp_fun(self.name, out, label=out_label) + + # references + self.refs = dict() + # invisible to `self.nodes()` + self.refs['pre'] = pre + self.refs['post'] = post + self.refs['syn'] = syn_cls + self.refs['out'] = out + # unify the access + self.refs['comm'] = comm + + def update(self): + x = _get_return(self.refs['syn'].return_info()) + current = self.comm(x) + self.refs['out'].bind_cond(current) + return current + + +class FullProjAlignPreSD(Projection): + """Full-chain synaptic projection with the align-pre reduction and synapse+delay updating. + + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``. 
+ + The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. + + The ``synapse+delay updating`` means that the projection first computes the synapse states, then delivers the + synapse states to the delay model, and finally computes the synaptic current. + + Neither ``FullProjAlignPreSD`` nor ``FullProjAlignPreDS``facilitates the event-driven computation. + This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather + than the spiking. To facilitate the event-driven computation, please use align post projections. + + + To simulate an E/I balanced network model: + + .. code-block:: python + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + ne, ni = 3200, 800 + self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreSD(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreSD(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreSD(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreSD(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + pre: The pre-synaptic neuron group. + syn: The synaptic dynamics. + delay: The synaptic delay. + comm: The synaptic communication. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. 
+ """ + + def __init__( + self, + pre: DynamicalSystem, + syn: JointType[DynamicalSystem, SupportAutoDelay], + delay: Union[None, int, float], + comm: DynamicalSystem, + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, DynamicalSystem) + check.is_instance(syn, JointType[DynamicalSystem, SupportAutoDelay]) + check.is_instance(comm, DynamicalSystem) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # synapse and delay initialization + delay_cls = init_delay_by_return(syn.return_info()) + delay_cls.register_entry(self.name, delay) + pre.add_aft_update(self.name, _AlignPre(syn, delay_cls)) + + # output initialization + post.add_inp_fun(self.name, out, label=out_label) + + # references + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['pre'] = pre + self.refs['post'] = post + self.refs['out'] = out + self.refs['delay'] = delay_cls + self.refs['syn'] = syn + # unify the access + self.refs['comm'] = comm + + def update(self, x=None): + if x is None: + x = self.refs['delay'].at(self.name) + current = self.comm(x) + self.refs['out'].bind_cond(current) + return current + + +class FullProjAlignPreDS(Projection): + """Full-chain synaptic projection with the align-pre reduction and delay+synapse updating. + + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``. + Note here, compared to ``FullProjAlignPreSD``, the ``delay`` and ``syn`` are exchanged. + + The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. + + The ``delay+synapse updating`` means that the projection first delivers the pre neuron output (usually the + spiking) to the delay model, then computes the synapse states, and finally computes the synaptic current. + + Neither ``FullProjAlignPreDS`` nor ``FullProjAlignPreSD`` facilitates the event-driven computation. + This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather + than the spiking. To facilitate the event-driven computation, please use align post projections. + + + To simulate an E/I balanced network model: + + .. 
code-block:: python + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + ne, ni = 3200, 800 + self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreDS(pre=self.E, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreDS(pre=self.E, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreDS(pre=self.I, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreDS(pre=self.I, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + pre: The pre-synaptic neuron group. + delay: The synaptic delay. + syn: The synaptic dynamics. + comm: The synaptic communication. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. 
+ """ + + def __init__( + self, + pre: JointType[DynamicalSystem, SupportAutoDelay], + delay: Union[None, int, float], + syn: DynamicalSystem, + comm: DynamicalSystem, + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) + check.is_instance(syn, DynamicalSystem) + check.is_instance(comm, DynamicalSystem) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + self.syn = syn + + # delay initialization + delay_cls = register_delay_by_return(pre) + delay_cls.register_entry(self.name, delay) + + # output initialization + post.add_inp_fun(self.name, out, label=out_label) + + # references + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['pre'] = pre + self.refs['post'] = post + self.refs['out'] = out + self.refs['delay'] = delay_cls + # unify the access + self.refs['syn'] = syn + self.refs['comm'] = comm + + def update(self): + spk = self.refs['delay'].at(self.name) + g = self.comm(self.syn(spk)) + self.refs['out'].bind_cond(g) + return g diff --git a/brainpy/_src/dyn/projections/aligns.py b/brainpy/_src/dyn/projections/aligns.py deleted file mode 100644 index 2616e928b..000000000 --- a/brainpy/_src/dyn/projections/aligns.py +++ /dev/null @@ -1,1053 +0,0 @@ -from typing import Optional, Callable, Union - -from brainpy import math as bm, check -from brainpy._src.delay import (Delay, DelayAccess, delay_identifier, - init_delay_by_return, register_delay_by_return) -from brainpy._src.dynsys import DynamicalSystem, Projection -from brainpy._src.mixin import (JointType, ParamDescriber, ReturnInfo, - SupportAutoDelay, BindCondData, AlignPost) - -__all__ = [ - 'VanillaProj', - 'ProjAlignPostMg1', 'ProjAlignPostMg2', - 'ProjAlignPost1', 'ProjAlignPost2', - 'ProjAlignPreMg1', 'ProjAlignPreMg2', - 'ProjAlignPre1', 'ProjAlignPre2', -] - - -def get_post_repr(out_label, syn, out): - return f'{out_label} // {syn.identifier} // {out.identifier}' - - -def add_inp_fun(out_label, proj_name, out, post): - # synapse and output initialization - if out_label is None: - out_name = proj_name - else: - out_name = f'{out_label} // {proj_name}' - post.add_inp_fun(out_name, out) - - -def align_post_add_bef_update(out_label, syn_desc, out_desc, post, proj_name): - # synapse and output initialization - _post_repr = get_post_repr(out_label, syn_desc, out_desc) - if not post.has_bef_update(_post_repr): - syn_cls = syn_desc() - out_cls = out_desc() - - # synapse and output initialization - if out_label is None: - out_name = proj_name - else: - out_name = f'{out_label} // {proj_name}' - post.add_inp_fun(out_name, out_cls) - post.add_bef_update(_post_repr, _AlignPost(syn_cls, out_cls)) - syn = post.get_bef_update(_post_repr).syn - out = post.get_bef_update(_post_repr).out - return syn, out - - -def align_pre2_add_bef_update(syn_desc, delay, delay_cls, proj_name=None): - _syn_id = f'Delay({str(delay)}) // {syn_desc.identifier}' - if not delay_cls.has_bef_update(_syn_id): - # delay - delay_access = DelayAccess(delay_cls, delay, delay_entry=proj_name) - # synapse - syn_cls = syn_desc() - # add to "after_updates" - delay_cls.add_bef_update(_syn_id, _AlignPreMg(delay_access, syn_cls)) - syn = delay_cls.get_bef_update(_syn_id).syn - return syn - - -def align_pre1_add_bef_update(syn_desc, 
pre): - _syn_id = f'{syn_desc.identifier} // Delay' - if not pre.has_aft_update(_syn_id): - # "syn_cls" needs an instance of "ProjAutoDelay" - syn_cls: SupportAutoDelay = syn_desc() - delay_cls = init_delay_by_return(syn_cls.return_info()) - # add to "after_updates" - pre.add_aft_update(_syn_id, _AlignPre(syn_cls, delay_cls)) - delay_cls: Delay = pre.get_aft_update(_syn_id).delay - syn = pre.get_aft_update(_syn_id).syn - return delay_cls, syn - - -class _AlignPre(DynamicalSystem): - def __init__(self, syn, delay=None): - super().__init__() - self.syn = syn - self.delay = delay - - def update(self, x): - if self.delay is None: - return x >> self.syn - else: - return x >> self.syn >> self.delay - - def reset_state(self, *args, **kwargs): - pass - - -class _AlignPost(DynamicalSystem): - def __init__(self, - syn: Callable, - out: JointType[DynamicalSystem, BindCondData]): - super().__init__() - self.syn = syn - self.out = out - - def update(self, *args, **kwargs): - self.out.bind_cond(self.syn(*args, **kwargs)) - - def reset_state(self, *args, **kwargs): - pass - - -class _AlignPreMg(DynamicalSystem): - def __init__(self, access, syn): - super().__init__() - self.access = access - self.syn = syn - - def update(self, *args, **kwargs): - return self.syn(self.access()) - - def reset_state(self, *args, **kwargs): - pass - - -def _get_return(return_info): - if isinstance(return_info, bm.Variable): - return return_info.value - elif isinstance(return_info, ReturnInfo): - return return_info.get_data() - else: - raise NotImplementedError - - -class VanillaProj(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of pre-synaptic neuron group. - - **Code Examples** - - To simulate an E/I balanced network model: - - .. code-block:: - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.syn1 = bp.dyn.Expon(size=3200, tau=5.) - self.syn2 = bp.dyn.Expon(size=800, tau=10.) - self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.N) - self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.N) - - def update(self, input): - spk = self.delay.at('I') - self.E(self.syn1(spk[:3200])) - self.I(self.syn2(spk[3200:])) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - - Args: - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. 
- """ - - def __init__( - self, - comm: DynamicalSystem, - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(comm, DynamicalSystem) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - - # output initialization - post.add_inp_fun(self.name, out) - - # references - self.refs = dict(post=post, out=out) # invisible to ``self.nodes()`` - self.refs['comm'] = comm # unify the access - - def update(self, x): - current = self.comm(x) - self.refs['out'].bind_cond(current) - return current - - -class ProjAlignPostMg1(Projection): - r"""Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. - - **Code Examples** - - To define an E/I balanced network model. - - .. code-block:: python - - import brainpy as bp - import brainpy.math as bm - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.E = bp.dyn.ProjAlignPostMg1(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), - syn=bp.dyn.Expon.desc(size=4000, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.N) - self.I = bp.dyn.ProjAlignPostMg1(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), - syn=bp.dyn.Expon.desc(size=4000, tau=10.), - out=bp.dyn.COBA.desc(E=-80.), - post=self.N) - - def update(self, input): - spk = self.delay.at('I') - self.E(spk[:3200]) - self.I(spk[3200:]) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - Args: - comm: The synaptic communication. - syn: The synaptic dynamics. - out: The synaptic output. - post: The post-synaptic neuron group. - out_label: str. The prefix of the output function. - name: str. The projection name. - mode: Mode. The computing mode. - """ - - def __init__( - self, - comm: DynamicalSystem, - syn: ParamDescriber[JointType[DynamicalSystem, AlignPost]], - out: ParamDescriber[JointType[DynamicalSystem, BindCondData]], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(comm, DynamicalSystem) - check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, AlignPost]]) - check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - - # synapse and output initialization - syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) - - # references - self.refs = dict(post=post) # invisible to ``self.nodes()`` - self.refs['syn'] = syn - self.refs['out'] = out - self.refs['comm'] = comm # unify the access - - def update(self, x): - current = self.comm(x) - self.refs['syn'].add_current(current) # synapse post current - return current - - -class ProjAlignPostMg2(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. 
- - **Code Examples** - - To define an E/I balanced network model. - - .. code-block:: python - - import brainpy as bp - import brainpy.math as bm - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - ne, ni = 3200, 800 - self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPostMg2(pre=self.E, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPostMg2(pre=self.E, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - syn=bp.dyn.Expon.desc(size=ni, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPostMg2(pre=self.I, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - syn=bp.dyn.Expon.desc(size=ne, tau=10.), - out=bp.dyn.COBA.desc(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPostMg2(pre=self.I, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - out=bp.dyn.COBA.desc(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - Args: - pre: The pre-synaptic neuron group. - delay: The synaptic delay. - comm: The synaptic communication. - syn: The synaptic dynamics. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. 
- """ - - def __init__( - self, - pre: JointType[DynamicalSystem, SupportAutoDelay], - delay: Union[None, int, float], - comm: DynamicalSystem, - syn: ParamDescriber[JointType[DynamicalSystem, AlignPost]], - out: ParamDescriber[JointType[DynamicalSystem, BindCondData]], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(comm, DynamicalSystem) - check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, AlignPost]]) - check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - - # delay initialization - delay_cls = register_delay_by_return(pre) - delay_cls.register_entry(self.name, delay) - - # synapse and output initialization - syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) - - # references - self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()`` - self.refs['syn'] = syn # invisible to ``self.node()`` - self.refs['out'] = out # invisible to ``self.node()`` - # unify the access - self.refs['comm'] = comm - self.refs['delay'] = pre.get_aft_update(delay_identifier) - - def update(self): - x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name) - current = self.comm(x) - self.refs['syn'].add_current(current) # synapse post current - return current - - -class ProjAlignPost1(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. - - To simulate an E/I balanced network: - - .. code-block:: - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.E = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), - syn=bp.dyn.Expon(size=4000, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.N) - self.I = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), - syn=bp.dyn.Expon(size=4000, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.N) - - def update(self, input): - spk = self.delay.at('I') - self.E(spk[:3200]) - self.I(spk[3200:]) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - - Args: - comm: The synaptic communication. - syn: The synaptic dynamics. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. 
- """ - - def __init__( - self, - comm: DynamicalSystem, - syn: JointType[DynamicalSystem, AlignPost], - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(comm, DynamicalSystem) - check.is_instance(syn, JointType[DynamicalSystem, AlignPost]) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - self.syn = syn - self.out = out - - # synapse and output initialization - add_inp_fun(out_label, self.name, out, post) - - # reference - self.refs = dict() - # invisible to ``self.nodes()`` - self.refs['post'] = post - self.refs['syn'] = syn - self.refs['out'] = out - # unify the access - self.refs['comm'] = comm - - def update(self, x): - current = self.comm(x) - g = self.syn(self.comm(x)) - self.refs['out'].bind_cond(g) # synapse post current - return current - - -class ProjAlignPost2(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. - - To simulate and define an E/I balanced network model: - - .. code-block:: python - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - ne, ni = 3200, 800 - self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPost2(pre=self.E, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - syn=bp.dyn.Expon(size=ne, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPost2(pre=self.E, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - syn=bp.dyn.Expon(size=ni, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPost2(pre=self.I, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - syn=bp.dyn.Expon(size=ne, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPost2(pre=self.I, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - syn=bp.dyn.Expon(size=ni, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - - Args: - pre: The pre-synaptic neuron group. - delay: The synaptic delay. - comm: The synaptic communication. - syn: The synaptic dynamics. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. 
- """ - - def __init__( - self, - pre: JointType[DynamicalSystem, SupportAutoDelay], - delay: Union[None, int, float], - comm: DynamicalSystem, - syn: JointType[DynamicalSystem, AlignPost], - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(comm, DynamicalSystem) - check.is_instance(syn, JointType[DynamicalSystem, AlignPost]) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - self.syn = syn - - # delay initialization - delay_cls = register_delay_by_return(pre) - delay_cls.register_entry(self.name, delay) - - # synapse and output initialization - add_inp_fun(out_label, self.name, out, post) - - # references - self.refs = dict() - # invisible to ``self.nodes()`` - self.refs['pre'] = pre - self.refs['post'] = post - self.refs['out'] = out - # unify the access - self.refs['delay'] = delay_cls - self.refs['comm'] = comm - self.refs['syn'] = syn - - def update(self): - x = self.refs['delay'].at(self.name) - g = self.syn(self.comm(x)) - self.refs['out'].bind_cond(g) # synapse post current - return g - - -class ProjAlignPreMg1(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. - - To simulate an E/I balanced network model: - - .. code-block:: python - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - ne, ni = 3200, 800 - self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPreMg1(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPreMg1(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPreMg1(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPreMg1(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - - Args: - pre: The pre-synaptic neuron group. - syn: The synaptic dynamics. - delay: The synaptic delay. - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. 
- """ - - def __init__( - self, - pre: DynamicalSystem, - syn: ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]], - delay: Union[None, int, float], - comm: DynamicalSystem, - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(pre, DynamicalSystem) - check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]]) - check.is_instance(comm, DynamicalSystem) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - - # synapse and delay initialization - delay_cls, syn_cls = align_pre1_add_bef_update(syn, pre) - delay_cls.register_entry(self.name, delay) - - # output initialization - add_inp_fun(out_label, self.name, out, post) - - # references - self.refs = dict() - # invisible to ``self.nodes()`` - self.refs['pre'] = pre - self.refs['post'] = post - self.refs['out'] = out - self.refs['delay'] = delay_cls - self.refs['syn'] = syn_cls - # unify the access - self.refs['comm'] = comm - - def update(self, x=None): - if x is None: - x = self.refs['delay'].at(self.name) - current = self.comm(x) - self.refs['out'].bind_cond(current) - return current - - -class ProjAlignPreMg2(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. - - To simulate an E/I balanced network model: - - .. code-block:: python - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - ne, ni = 3200, 800 - self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPreMg2(pre=self.E, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPreMg2(pre=self.E, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPreMg2(pre=self.I, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPreMg2(pre=self.I, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - - Args: - pre: The pre-synaptic neuron group. - delay: The synaptic delay. - syn: The synaptic dynamics. - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. 
- """ - - def __init__( - self, - pre: JointType[DynamicalSystem, SupportAutoDelay], - delay: Union[None, int, float], - syn: ParamDescriber[DynamicalSystem], - comm: DynamicalSystem, - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(syn, ParamDescriber[DynamicalSystem]) - check.is_instance(comm, DynamicalSystem) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - - # delay initialization - delay_cls = register_delay_by_return(pre) - - # synapse initialization - syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name) - - # output initialization - add_inp_fun(out_label, self.name, out, post) - - # references - self.refs = dict() - # invisible to `self.nodes()` - self.refs['pre'] = pre - self.refs['post'] = post - self.refs['syn'] = syn_cls - self.refs['out'] = out - # unify the access - self.refs['comm'] = comm - - def update(self): - x = _get_return(self.refs['syn'].return_info()) - current = self.comm(x) - self.refs['out'].bind_cond(current) - return current - - -class ProjAlignPre1(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. - - To simulate an E/I balanced network model: - - .. code-block:: python - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - ne, ni = 3200, 800 - self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPreMg1(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPreMg1(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPreMg1(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPreMg1(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - - Args: - pre: The pre-synaptic neuron group. - syn: The synaptic dynamics. - delay: The synaptic delay. - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. 
- """ - - def __init__( - self, - pre: DynamicalSystem, - syn: JointType[DynamicalSystem, SupportAutoDelay], - delay: Union[None, int, float], - comm: DynamicalSystem, - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(pre, DynamicalSystem) - check.is_instance(syn, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(comm, DynamicalSystem) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - - # synapse and delay initialization - delay_cls = init_delay_by_return(syn.return_info()) - delay_cls.register_entry(self.name, delay) - pre.add_aft_update(self.name, _AlignPre(syn, delay_cls)) - - # output initialization - add_inp_fun(out_label, self.name, out, post) - - # references - self.refs = dict() - # invisible to ``self.nodes()`` - self.refs['pre'] = pre - self.refs['post'] = post - self.refs['out'] = out - self.refs['delay'] = delay_cls - self.refs['syn'] = syn - # unify the access - self.refs['comm'] = comm - - def update(self, x=None): - if x is None: - x = self.refs['delay'].at(self.name) - current = self.comm(x) - self.refs['out'].bind_cond(current) - return current - - -class ProjAlignPre2(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. - - To simulate an E/I balanced network model: - - .. code-block:: python - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - ne, ni = 3200, 800 - self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPreMg2(pre=self.E, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPreMg2(pre=self.E, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPreMg2(pre=self.I, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPreMg2(pre=self.I, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - - Args: - pre: The pre-synaptic neuron group. - delay: The synaptic delay. - syn: The synaptic dynamics. - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. 
- """ - - def __init__( - self, - pre: JointType[DynamicalSystem, SupportAutoDelay], - delay: Union[None, int, float], - syn: DynamicalSystem, - comm: DynamicalSystem, - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(syn, DynamicalSystem) - check.is_instance(comm, DynamicalSystem) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - self.syn = syn - - # delay initialization - delay_cls = register_delay_by_return(pre) - delay_cls.register_entry(self.name, delay) - - # output initialization - add_inp_fun(out_label, self.name, out, post) - - # references - self.refs = dict() - # invisible to ``self.nodes()`` - self.refs['pre'] = pre - self.refs['post'] = post - self.refs['out'] = out - self.refs['delay'] = delay_cls - # unify the access - self.refs['syn'] = syn - self.refs['comm'] = comm - - def update(self): - spk = self.refs['delay'].at(self.name) - g = self.comm(self.syn(spk)) - self.refs['out'].bind_cond(g) - return g diff --git a/brainpy/_src/dyn/projections/delta.py b/brainpy/_src/dyn/projections/delta.py new file mode 100644 index 000000000..19e4938cb --- /dev/null +++ b/brainpy/_src/dyn/projections/delta.py @@ -0,0 +1,210 @@ +from typing import Optional, Union + +from brainpy import math as bm, check +from brainpy._src.delay import (delay_identifier, register_delay_by_return) +from brainpy._src.dynsys import DynamicalSystem, Projection +from brainpy._src.mixin import (JointType, SupportAutoDelay) + +__all__ = [ + 'HalfProjDelta', 'FullProjDelta', +] + + +class _Delta: + def __init__(self): + self._cond = None + + def bind_cond(self, cond): + self._cond = cond + + def __call__(self, *args, **kwargs): + r = self._cond + return r + + +class HalfProjDelta(Projection): + """Defining the half-part of the synaptic projection for the Delta synapse model. + + The synaptic projection requires the input is the spiking data, otherwise + the synapse is not the Delta synapse model. + + The ``half-part`` means that the model only includes ``comm`` -> ``syn`` -> ``out`` -> ``post``. + Therefore, the model's ``update`` function needs the manual providing of the spiking input. + + **Model Descriptions** + + .. math:: + + I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \delta(t-t_j-D) + + where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength, + :math:`t_j` the spiking moment of the presynaptic neuron :math:`j`, + :math:`C` the set of neurons connected to the post-synaptic neuron, + and :math:`D` the transmission delay of chemical synapses. + For simplicity, the rise and decay phases of post-synaptic currents are + omitted in this model. + + + **Code Examples** + + .. code-block:: + + import brainpy as bp + import brainpy.math as bm + + class Net(bp.DynamicalSystem): + def __init__(self): + super().__init__() + + self.pre = bp.dyn.PoissonGroup(10, 100.) 
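            # ten Poisson units firing at 100 Hz; their spikes are passed to the
            # projection by hand in update() below, since the half-chain model
            # only covers comm -> out -> post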
+ self.post = bp.dyn.LifRef(1) + self.syn = bp.dyn.HalfProjDelta(bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post) + + def update(self): + self.syn(self.pre()) + self.post() + return self.post.V.value + + net = Net() + indices = bm.arange(1000).to_numpy() + vs = bm.for_loop(net.step_run, indices, progress_bar=True) + bp.visualize.line_plot(indices, vs, show=True) + + Args: + comm: DynamicalSystem. The synaptic communication. + post: DynamicalSystem. The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. + """ + + def __init__( + self, + comm: DynamicalSystem, + post: DynamicalSystem, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(comm, DynamicalSystem) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # output initialization + out = _Delta() + post.add_inp_fun(self.name, out, category='delta') + + # references + self.refs = dict(post=post, out=out) # invisible to ``self.nodes()`` + self.refs['comm'] = comm # unify the access + + def update(self, x): + # call the communication + current = self.comm(x) + # bind the output + self.refs['out'].bind_cond(current) + # return the current, if needed + return current + + +class FullProjDelta(Projection): + """Full-chain of the synaptic projection for the Delta synapse model. + + The synaptic projection requires the input is the spiking data, otherwise + the synapse is not the Delta synapse model. + + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``delay`` -> ``comm`` -> ``post``. + + **Model Descriptions** + + .. math:: + + I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \delta(t-t_j-D) + + where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength, + :math:`t_j` the spiking moment of the presynaptic neuron :math:`j`, + :math:`C` the set of neurons connected to the post-synaptic neuron, + and :math:`D` the transmission delay of chemical synapses. + For simplicity, the rise and decay phases of post-synaptic currents are + omitted in this model. + + + **Code Examples** + + .. code-block:: + + import brainpy as bp + import brainpy.math as bm + + + class Net(bp.DynamicalSystem): + def __init__(self): + super().__init__() + + self.pre = bp.dyn.PoissonGroup(10, 100.) + self.post = bp.dyn.LifRef(1) + self.syn = bp.dyn.FullProjDelta(self.pre, 0., bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post) + + def update(self): + self.syn() + self.pre() + self.post() + return self.post.V.value + + + net = Net() + indices = bm.arange(1000).to_numpy() + vs = bm.for_loop(net.step_run, indices, progress_bar=True) + bp.visualize.line_plot(indices, vs, show=True) + + + Args: + pre: The pre-synaptic neuron group. + delay: The synaptic delay. + comm: DynamicalSystem. The synaptic communication. + post: DynamicalSystem. The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. 
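    The user-visible difference from ``HalfProjDelta`` is only the call pattern: the full-chain
    variant registers a delay on ``pre`` and reads the spikes back under the projection's name,
    so ``update`` takes no argument. A minimal sketch, reusing the ``pre``, ``comm`` and ``post``
    objects from the examples above (``pre_spikes`` is a placeholder for the spike array produced
    in the same step):

    .. code-block:: python

        half = bp.dyn.HalfProjDelta(comm, post)           # caller supplies the spikes
        full = bp.dyn.FullProjDelta(pre, 0., comm, post)  # spikes come from the delay registered on pre

        # inside a network's update():
        half(pre_spikes)   # comm -> out -> post
        full()             # pre -> delay -> comm -> out -> post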
+ """ + + def __init__( + self, + pre: JointType[DynamicalSystem, SupportAutoDelay], + delay: Union[None, int, float], + comm: DynamicalSystem, + post: DynamicalSystem, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) + check.is_instance(comm, DynamicalSystem) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # delay initialization + delay_cls = register_delay_by_return(pre) + delay_cls.register_entry(self.name, delay) + + # output initialization + out = _Delta() + post.add_inp_fun(self.name, out, category='delta') + + # references + self.refs = dict(pre=pre, post=post, out=out) # invisible to ``self.nodes()`` + self.refs['comm'] = comm # unify the access + self.refs['delay'] = pre.get_aft_update(delay_identifier) + + def update(self): + # get delay + x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name) + # call the communication + current = self.comm(x) + # bind the output + self.refs['out'].bind_cond(current) + # return the current, if needed + return current diff --git a/brainpy/_src/dyn/projections/inputs.py b/brainpy/_src/dyn/projections/inputs.py index f0001988b..dd1e1e3df 100644 --- a/brainpy/_src/dyn/projections/inputs.py +++ b/brainpy/_src/dyn/projections/inputs.py @@ -1,96 +1,167 @@ -from typing import Optional, Any +import numbers +from typing import Any +from typing import Union, Optional -from brainpy import math as bm +from brainpy import check, math as bm +from brainpy._src.context import share from brainpy._src.dynsys import Dynamic +from brainpy._src.dynsys import Projection from brainpy._src.mixin import SupportAutoDelay from brainpy.types import Shape __all__ = [ - 'InputVar', + 'InputVar', + 'PoissonInput', ] class InputVar(Dynamic, SupportAutoDelay): - """Define an input variable. + """Define an input variable. - Example:: + Example:: + + import brainpy as bp - import brainpy as bp - - class Exponential(bp.Projection): - def __init__(self, pre, post, prob, g_max, tau, E=0.): - super().__init__() - self.proj = bp.dyn.ProjAlignPostMg2( - pre=pre, - delay=None, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), - syn=bp.dyn.Expon.desc(post.num, tau=tau), - out=bp.dyn.COBA.desc(E=E), - post=post, - ) - - - class EINet(bp.DynSysGroup): - def __init__(self, num_exc, num_inh, method='exp_auto'): - super(EINet, self).__init__() - - # neurons - pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.), method=method) - self.E = bp.dyn.LifRef(num_exc, **pars) - self.I = bp.dyn.LifRef(num_inh, **pars) - - # synapses - w_e = 0.6 # excitatory synaptic weight - w_i = 6.7 # inhibitory synaptic weight - - # Neurons connect to each other randomly with a connection probability of 2% - self.E2E = Exponential(self.E, self.E, 0.02, g_max=w_e, tau=5., E=0.) - self.E2I = Exponential(self.E, self.I, 0.02, g_max=w_e, tau=5., E=0.) - self.I2E = Exponential(self.I, self.E, 0.02, g_max=w_i, tau=10., E=-80.) - self.I2I = Exponential(self.I, self.I, 0.02, g_max=w_i, tau=10., E=-80.) 
- - # define input variables given to E/I populations - self.Ein = bp.dyn.InputVar(self.E.varshape) - self.Iin = bp.dyn.InputVar(self.I.varshape) - self.E.add_inp_fun('', self.Ein) - self.I.add_inp_fun('', self.Iin) - - - net = EINet(3200, 800, method='exp_auto') # "method": the numerical integrator method - runner = bp.DSRunner(net, monitors=['E.spike', 'I.spike'], inputs=[('Ein.input', 20.), ('Iin.input', 20.)]) - runner.run(100.) - - # visualization - bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], - title='Spikes of Excitatory Neurons', show=True) - bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'], - title='Spikes of Inhibitory Neurons', show=True) - - - """ - def __init__( - self, - size: Shape, - keep_size: bool = False, - sharding: Optional[Any] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - method: str = 'exp_auto' - ): - super().__init__(size=size, keep_size=keep_size, sharding=sharding, name=name, mode=mode, method=method) - - self.reset_state(self.mode) - - def reset_state(self, batch_or_mode=None, **kwargs): - self.input = self.init_variable(bm.zeros, batch_or_mode) - - def update(self, *args, **kwargs): - return self.input.value - - def return_info(self): - return self.input - - def clear_input(self, *args, **kwargs): - self.reset_state(self.mode) + class Exponential(bp.Projection): + def __init__(self, pre, post, prob, g_max, tau, E=0.): + super().__init__() + self.proj = bp.dyn.ProjAlignPostMg2( + pre=pre, + delay=None, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), + syn=bp.dyn.Expon.desc(post.num, tau=tau), + out=bp.dyn.COBA.desc(E=E), + post=post, + ) + + + class EINet(bp.DynSysGroup): + def __init__(self, num_exc, num_inh, method='exp_auto'): + super(EINet, self).__init__() + + # neurons + pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.), method=method) + self.E = bp.dyn.LifRef(num_exc, **pars) + self.I = bp.dyn.LifRef(num_inh, **pars) + + # synapses + w_e = 0.6 # excitatory synaptic weight + w_i = 6.7 # inhibitory synaptic weight + + # Neurons connect to each other randomly with a connection probability of 2% + self.E2E = Exponential(self.E, self.E, 0.02, g_max=w_e, tau=5., E=0.) + self.E2I = Exponential(self.E, self.I, 0.02, g_max=w_e, tau=5., E=0.) + self.I2E = Exponential(self.I, self.E, 0.02, g_max=w_i, tau=10., E=-80.) + self.I2I = Exponential(self.I, self.I, 0.02, g_max=w_i, tau=10., E=-80.) + + # define input variables given to E/I populations + self.Ein = bp.dyn.InputVar(self.E.varshape) + self.Iin = bp.dyn.InputVar(self.I.varshape) + self.E.add_inp_fun('', self.Ein) + self.I.add_inp_fun('', self.Iin) + + + net = EINet(3200, 800, method='exp_auto') # "method": the numerical integrator method + runner = bp.DSRunner(net, monitors=['E.spike', 'I.spike'], inputs=[('Ein.input', 20.), ('Iin.input', 20.)]) + runner.run(100.) 
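        # "inputs" feeds a constant 20. into Ein.input and Iin.input at every step of the 100 ms run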
+ + # visualization + bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], + title='Spikes of Excitatory Neurons', show=True) + bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'], + title='Spikes of Inhibitory Neurons', show=True) + + + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + sharding: Optional[Any] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + method: str = 'exp_auto' + ): + super().__init__(size=size, keep_size=keep_size, sharding=sharding, name=name, mode=mode, method=method) + + self.reset_state(self.mode) + + def reset_state(self, batch_or_mode=None, **kwargs): + self.input = self.init_variable(bm.zeros, batch_or_mode) + + def update(self, *args, **kwargs): + return self.input.value + + def return_info(self): + return self.input + + def clear_input(self, *args, **kwargs): + self.reset_state(self.mode) + + +class PoissonInput(Projection): + """Poisson Input to the given :py:class:`~.Variable`. + + Adds independent Poisson input to a target variable. For large + numbers of inputs, this is much more efficient than creating a + `PoissonGroup`. The synaptic events are generated randomly during the + simulation and are not preloaded and stored in memory. All the inputs must + target the same variable, have the same frequency and same synaptic weight. + All neurons in the target variable receive independent realizations of + Poisson spike trains. + + Args: + target_var: The variable that is targeted by this input. Should be an instance of :py:class:`~.Variable`. + num_input: The number of inputs. + freq: The frequency of each of the inputs. Must be a scalar. + weight: The synaptic weight. Must be a scalar. + name: The target name. + mode: The computing mode. + """ + + def __init__( + self, + target_var: bm.Variable, + num_input: int, + freq: Union[int, float], + weight: Union[int, float], + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) + + if not isinstance(target_var, bm.Variable): + raise TypeError(f'"target_var" must be an instance of Variable. 
' + f'But we got {type(target_var)}: {target_var}') + self.target_var = target_var + self.num_input = check.is_integer(num_input, min_bound=1) + self.freq = check.is_float(freq, min_bound=0., allow_int=True) + self.weight = check.is_float(weight, allow_int=True) + + def reset_state(self, *args, **kwargs): + pass + + def update(self): + p = self.freq * share['dt'] / 1e3 + a = self.num_input * p + b = self.num_input * (1 - p) + + if isinstance(share['dt'], numbers.Number): # dt is not traced + if (a > 5) and (b > 5): + inp = bm.random.normal(a, b * p, self.target_var.shape) + else: + inp = bm.random.binomial(self.num_input, p, self.target_var.shape) + + else: # dt is traced + inp = bm.cond((a > 5) * (b > 5), + lambda: bm.random.normal(a, b * p, self.target_var.shape), + lambda: bm.random.binomial(self.num_input, p, self.target_var.shape)) + + # inp = bm.sharding.partition(inp, self.target_var.sharding) + self.target_var += inp * self.weight + + def __repr__(self): + return f'{self.name}(num_input={self.num_input}, freq={self.freq}, weight={self.weight})' diff --git a/brainpy/_src/dyn/projections/others.py b/brainpy/_src/dyn/projections/others.py deleted file mode 100644 index 72a77298f..000000000 --- a/brainpy/_src/dyn/projections/others.py +++ /dev/null @@ -1,81 +0,0 @@ -import numbers -import warnings -from typing import Union, Optional - -from brainpy import check, math as bm -from brainpy._src.context import share -from brainpy._src.dynsys import Projection - -__all__ = [ - 'PoissonInput', -] - - -class PoissonInput(Projection): - """Poisson Input to the given :py:class:`~.Variable`. - - Adds independent Poisson input to a target variable. For large - numbers of inputs, this is much more efficient than creating a - `PoissonGroup`. The synaptic events are generated randomly during the - simulation and are not preloaded and stored in memory. All the inputs must - target the same variable, have the same frequency and same synaptic weight. - All neurons in the target variable receive independent realizations of - Poisson spike trains. - - Args: - target_var: The variable that is targeted by this input. Should be an instance of :py:class:`~.Variable`. - num_input: The number of inputs. - freq: The frequency of each of the inputs. Must be a scalar. - weight: The synaptic weight. Must be a scalar. - name: The target name. - mode: The computing mode. - """ - - def __init__( - self, - target_var: bm.Variable, - num_input: int, - freq: Union[int, float], - weight: Union[int, float], - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - seed=None - ): - super().__init__(name=name, mode=mode) - - if seed is not None: - warnings.warn('') - - if not isinstance(target_var, bm.Variable): - raise TypeError(f'"target_var" must be an instance of Variable. 
' - f'But we got {type(target_var)}: {target_var}') - self.target_var = target_var - self.num_input = check.is_integer(num_input, min_bound=1) - self.freq = check.is_float(freq, min_bound=0., allow_int=True) - self.weight = check.is_float(weight, allow_int=True) - - def reset_state(self, *args, **kwargs): - pass - - def update(self): - p = self.freq * share['dt'] / 1e3 - a = self.num_input * p - b = self.num_input * (1 - p) - - if isinstance(share['dt'], numbers.Number): # dt is not traced - if (a > 5) and (b > 5): - inp = bm.random.normal(a, b * p, self.target_var.shape) - else: - inp = bm.random.binomial(self.num_input, p, self.target_var.shape) - - else: # dt is traced - inp = bm.cond((a > 5) * (b > 5), - lambda: bm.random.normal(a, b * p, self.target_var.shape), - lambda: bm.random.binomial(self.num_input, p, self.target_var.shape), - ()) - - # inp = bm.sharding.partition(inp, self.target_var.sharding) - self.target_var += inp * self.weight - - def __repr__(self): - return f'{self.name}(num_input={self.num_input}, freq={self.freq}, weight={self.weight})' diff --git a/brainpy/_src/dyn/projections/plasticity.py b/brainpy/_src/dyn/projections/plasticity.py index 3fb3c1232..d36074b9c 100644 --- a/brainpy/_src/dyn/projections/plasticity.py +++ b/brainpy/_src/dyn/projections/plasticity.py @@ -7,8 +7,9 @@ from brainpy._src.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData, AlignPost, SupportSTDP) from brainpy.types import ArrayType -from .aligns import (_get_return, align_post_add_bef_update, - align_pre2_add_bef_update, add_inp_fun) +from .align_post import (align_post_add_bef_update, ) +from .align_pre import (align_pre2_add_bef_update, ) +from .utils import (_get_return, ) __all__ = [ 'STDP_Song2000', @@ -165,7 +166,7 @@ def __init__( else: syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name + '-pre') out_cls = out() - add_inp_fun(out_label, self.name, out_cls, post) + post.add_inp_fun(self.name, out_cls, label=out_label) # references self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()`` diff --git a/brainpy/_src/dyn/projections/tests/test_STDP.py b/brainpy/_src/dyn/projections/tests/test_STDP.py index a4173c7ba..b8884f327 100644 --- a/brainpy/_src/dyn/projections/tests/test_STDP.py +++ b/brainpy/_src/dyn/projections/tests/test_STDP.py @@ -86,7 +86,7 @@ def update(self, I_pre, I_post): conductance = self.syn.refs['syn'].g Apre = self.syn.refs['pre_trace'].g Apost = self.syn.refs['post_trace'].g - current = self.post.sum_inputs(self.post.V) + current = self.post.sum_current_inputs(self.post.V) if comm_method == 'dense': w = self.syn.comm.W.flatten() else: diff --git a/brainpy/_src/dyn/projections/tests/test_aligns.py b/brainpy/_src/dyn/projections/tests/test_aligns.py index 32b072e5a..90500a26f 100644 --- a/brainpy/_src/dyn/projections/tests/test_aligns.py +++ b/brainpy/_src/dyn/projections/tests/test_aligns.py @@ -19,7 +19,7 @@ def __init__(self, scale=1., inp=20., delay=None): prob = 80 / (4000 * scale) - self.E2I = bp.dyn.ProjAlignPreMg1( + self.E2I = bp.dyn.FullProjAlignPreSDMg( pre=self.E, syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), delay=delay, @@ -27,7 +27,7 @@ def __init__(self, scale=1., inp=20., delay=None): out=bp.dyn.COBA(E=0.), post=self.I, ) - self.E2E = bp.dyn.ProjAlignPreMg1( + self.E2E = bp.dyn.FullProjAlignPreSDMg( pre=self.E, syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), delay=delay, @@ -35,7 +35,7 @@ def __init__(self, scale=1., inp=20., delay=None): out=bp.dyn.COBA(E=0.), post=self.E, ) - self.I2E = 
bp.dyn.ProjAlignPreMg1( + self.I2E = bp.dyn.FullProjAlignPreSDMg( pre=self.I, syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), delay=delay, @@ -43,7 +43,7 @@ def __init__(self, scale=1., inp=20., delay=None): out=bp.dyn.COBA(E=-80.), post=self.E, ) - self.I2I = bp.dyn.ProjAlignPreMg1( + self.I2I = bp.dyn.FullProjAlignPreSDMg( pre=self.I, syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), delay=delay, @@ -90,7 +90,7 @@ def __init__(self, scale, inp=20., ltc=True, delay=None): prob = 80 / (4000 * scale) - self.E2E = bp.dyn.ProjAlignPostMg2( + self.E2E = bp.dyn.FullProjAlignPostMg( pre=self.E, delay=delay, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), @@ -98,7 +98,7 @@ def __init__(self, scale, inp=20., ltc=True, delay=None): out=bp.dyn.COBA.desc(E=0.), post=self.E, ) - self.E2I = bp.dyn.ProjAlignPostMg2( + self.E2I = bp.dyn.FullProjAlignPostMg( pre=self.E, delay=delay, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), @@ -106,7 +106,7 @@ def __init__(self, scale, inp=20., ltc=True, delay=None): out=bp.dyn.COBA.desc(E=0.), post=self.I, ) - self.I2E = bp.dyn.ProjAlignPostMg2( + self.I2E = bp.dyn.FullProjAlignPostMg( pre=self.I, delay=delay, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), @@ -114,7 +114,7 @@ def __init__(self, scale, inp=20., ltc=True, delay=None): out=bp.dyn.COBA.desc(E=-80.), post=self.E, ) - self.I2I = bp.dyn.ProjAlignPostMg2( + self.I2I = bp.dyn.FullProjAlignPostMg( pre=self.I, delay=delay, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), @@ -163,14 +163,14 @@ def __init__(self, scale=1.): self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.E = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(self.num_exc, num, prob=prob, weight=0.6), - syn=bp.dyn.Expon(size=num, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.N) - self.I = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(self.num_inh, num, prob=prob, weight=6.7), - syn=bp.dyn.Expon(size=num, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.N) + self.E = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_exc, num, prob=prob, weight=0.6), + syn=bp.dyn.Expon(size=num, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_inh, num, prob=prob, weight=6.7), + syn=bp.dyn.Expon(size=num, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.N) def update(self, input): spk = self.delay.at('I') @@ -198,30 +198,30 @@ def __init__(self, scale, delay=None): V_initializer=bp.init.Normal(-55., 2.)) self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPost2(pre=self.E, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=p, weight=0.6), - syn=bp.dyn.Expon(size=ne, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPost2(pre=self.E, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=p, weight=0.6), - syn=bp.dyn.Expon(size=ni, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPost2(pre=self.I, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=p, weight=6.7), - syn=bp.dyn.Expon(size=ne, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = 
bp.dyn.ProjAlignPost2(pre=self.I, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=p, weight=6.7), - syn=bp.dyn.Expon(size=ni, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.I) + self.E2E = bp.dyn.FullProjAlignPost(pre=self.E, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=p, weight=0.6), + syn=bp.dyn.Expon(size=ne, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPost(pre=self.E, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=p, weight=0.6), + syn=bp.dyn.Expon(size=ni, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPost(pre=self.I, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=p, weight=6.7), + syn=bp.dyn.Expon(size=ne, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPost(pre=self.I, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=p, weight=6.7), + syn=bp.dyn.Expon(size=ni, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.I) def update(self, inp): self.E2E() @@ -292,30 +292,30 @@ def __init__(self, scale=1., delay=None): V_initializer=bp.init.Normal(-55., 2.)) self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPreMg1(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPreMg1(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPreMg1(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPreMg1(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) + self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) def update(self, inp): self.E2E() @@ -350,30 +350,30 @@ def __init__(self, scale=1., delay=None): V_initializer=bp.init.Normal(-55., 2.)) self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPreMg2(pre=self.E, - delay=delay, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPreMg2(pre=self.E, - delay=delay, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - 
comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPreMg2(pre=self.I, - delay=delay, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPreMg2(pre=self.I, - delay=delay, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) + self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=delay, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=delay, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=delay, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=delay, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) def update(self, inp): self.E2E() diff --git a/brainpy/_src/dyn/projections/tests/test_delta.py b/brainpy/_src/dyn/projections/tests/test_delta.py new file mode 100644 index 000000000..f4d21b643 --- /dev/null +++ b/brainpy/_src/dyn/projections/tests/test_delta.py @@ -0,0 +1,51 @@ +import matplotlib.pyplot as plt + +import brainpy as bp +import brainpy.math as bm + + +class NetForHalfProj(bp.DynamicalSystem): + def __init__(self): + super().__init__() + + self.pre = bp.dyn.PoissonGroup(10, 100.) + self.post = bp.dyn.LifRef(1) + self.syn = bp.dyn.HalfProjDelta(bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post) + + def update(self): + self.syn(self.pre()) + self.post() + return self.post.V.value + + +def test1(): + net = NetForHalfProj() + indices = bm.arange(1000).to_numpy() + vs = bm.for_loop(net.step_run, indices, progress_bar=True) + bp.visualize.line_plot(indices, vs, show=False) + plt.close('all') + + +class NetForFullProj(bp.DynamicalSystem): + def __init__(self): + super().__init__() + + self.pre = bp.dyn.PoissonGroup(10, 100.) 
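        # full-chain delta projection assembled below:
        #   pre  -> the Poisson group above, whose spikes get an automatically registered delay
        #   0.   -> the delay length
        #   comm -> Linear(10, 1) with every weight initialized to 2.
        #   post -> the single LifRef neuron that receives the delta input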
+ self.post = bp.dyn.LifRef(1) + self.syn = bp.dyn.FullProjDelta(self.pre, 0., bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post) + + def update(self): + self.syn() + self.pre() + self.post() + return self.post.V.value + + +def test2(): + net = NetForFullProj() + indices = bm.arange(1000).to_numpy() + vs = bm.for_loop(net.step_run, indices, progress_bar=True) + bp.visualize.line_plot(indices, vs, show=False) + plt.close('all') + + diff --git a/brainpy/_src/dyn/projections/utils.py b/brainpy/_src/dyn/projections/utils.py new file mode 100644 index 000000000..44a2273a4 --- /dev/null +++ b/brainpy/_src/dyn/projections/utils.py @@ -0,0 +1,12 @@ +from brainpy import math as bm +from brainpy._src.mixin import ReturnInfo + + +def _get_return(return_info): + if isinstance(return_info, bm.Variable): + return return_info.value + elif isinstance(return_info, ReturnInfo): + return return_info.get_data() + else: + raise NotImplementedError + diff --git a/brainpy/_src/dyn/projections/vanilla.py b/brainpy/_src/dyn/projections/vanilla.py new file mode 100644 index 000000000..15773d231 --- /dev/null +++ b/brainpy/_src/dyn/projections/vanilla.py @@ -0,0 +1,83 @@ +from typing import Optional + +from brainpy import math as bm, check +from brainpy._src.dynsys import DynamicalSystem, Projection +from brainpy._src.mixin import (JointType, BindCondData) + +__all__ = [ + 'VanillaProj', +] + + +class VanillaProj(Projection): + """Synaptic projection which defines the synaptic computation with the dimension of pre-synaptic neuron group. + + **Code Examples** + + To simulate an E/I balanced network model: + + .. code-block:: + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.syn1 = bp.dyn.Expon(size=3200, tau=5.) + self.syn2 = bp.dyn.Expon(size=800, tau=10.) + self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(self.syn1(spk[:3200])) + self.I(self.syn2(spk[3200:])) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + comm: The synaptic communication. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. 
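+
+  Note that ``VanillaProj`` only manages the synaptic communication ``comm`` and the
+  synaptic output ``out``. As the example above shows, the spike delay and the
+  synaptic dynamics are updated explicitly in ``update()``, and the resulting
+  synaptic value is passed to the projection call.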
+ """ + + def __init__( + self, + comm: DynamicalSystem, + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(comm, DynamicalSystem) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # output initialization + post.add_inp_fun(self.name, out) + + # references + self.refs = dict(post=post, out=out) # invisible to ``self.nodes()`` + self.refs['comm'] = comm # unify the access + + def update(self, x): + current = self.comm(x) + self.refs['out'].bind_cond(current) + return current diff --git a/brainpy/_src/dyn/synapses/abstract_models.py b/brainpy/_src/dyn/synapses/abstract_models.py index 4a6b9ddb6..5fad9482d 100644 --- a/brainpy/_src/dyn/synapses/abstract_models.py +++ b/brainpy/_src/dyn/synapses/abstract_models.py @@ -10,7 +10,6 @@ from brainpy.types import ArrayType __all__ = [ - 'Delta', 'Expon', 'DualExpon', 'DualExponV2', @@ -21,69 +20,6 @@ ] -class Delta(SynDyn, AlignPost): - r"""Delta synapse model. - - **Model Descriptions** - - The single exponential decay synapse model assumes the release of neurotransmitter, - its diffusion across the cleft, the receptor binding, and channel opening all happen - very quickly, so that the channels instantaneously jump from the closed to the open state. - Therefore, its expression is given by - - .. math:: - - g_{\mathrm{syn}}(t)=g_{\mathrm{max}} e^{-\left(t-t_{0}\right) / \tau} - - where :math:`\tau_{delay}` is the time constant of the synaptic state decay, - :math:`t_0` is the time of the pre-synaptic spike, - :math:`g_{\mathrm{max}}` is the maximal conductance. - - Accordingly, the differential form of the exponential synapse is given by - - .. math:: - - \begin{aligned} - & \frac{d g}{d t} = -\frac{g}{\tau_{decay}}+\sum_{k} \delta(t-t_{j}^{k}). - \end{aligned} - - .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. - "The Synapse." Principles of Computational Modelling in Neuroscience. - Cambridge: Cambridge UP, 2011. 172-95. Print. - - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - sharding: Optional[Sequence[str]] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, - mode=mode, - size=size, - keep_size=keep_size, - sharding=sharding) - - self.reset_state(self.mode) - - def reset_state(self, batch_or_mode=None, **kwargs): - self.g = self.init_variable(bm.zeros, batch_or_mode) - - def update(self, x=None): - if x is not None: - self.g.value += x - return self.g.value - - def add_current(self, x): - self.g.value += x - - def return_info(self): - return self.g - - class Expon(SynDyn, AlignPost): r"""Exponential decay synapse model. 
@@ -1030,4 +966,4 @@ def return_info(self): lambda shape: self.u * self.x) -STP.__doc__ = STP.__doc__ % (pneu_doc,) \ No newline at end of file +STP.__doc__ = STP.__doc__ % (pneu_doc,) diff --git a/brainpy/_src/dynold/synapses/base.py b/brainpy/_src/dynold/synapses/base.py index a2bc1bdd5..55bac7111 100644 --- a/brainpy/_src/dynold/synapses/base.py +++ b/brainpy/_src/dynold/synapses/base.py @@ -6,7 +6,7 @@ from brainpy import math as bm from brainpy._src.connect import TwoEndConnector, One2One, All2All from brainpy._src.dnn import linear -from brainpy._src.dyn import projections +from brainpy._src.dyn.projections.conn import SynConn from brainpy._src.dyn.base import NeuDyn from brainpy._src.dynsys import DynamicalSystem from brainpy._src.initialize import parameter @@ -29,7 +29,7 @@ class _SynapseComponent(DynamicalSystem): synaptic long-term plasticity, and others. """ '''Master of this component.''' - master: projections.SynConn + master: SynConn def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -50,9 +50,9 @@ def isregistered(self, val: bool): def reset_state(self, batch_size=None): pass - def register_master(self, master: projections.SynConn): - if not isinstance(master, projections.SynConn): - raise TypeError(f'master must be instance of {projections.SynConn.__name__}, but we got {type(master)}') + def register_master(self, master: SynConn): + if not isinstance(master, SynConn): + raise TypeError(f'master must be instance of {SynConn.__name__}, but we got {type(master)}') if self.isregistered: raise ValueError(f'master has been registered, but we got another master going to be registered.') if hasattr(self, 'master') and self.master != master: @@ -90,7 +90,7 @@ def __init__( f'But we got {type(target_var)}') self.target_var: Optional[bm.Variable] = target_var - def register_master(self, master: projections.SynConn): + def register_master(self, master: SynConn): super().register_master(master) # initialize target variable to output @@ -125,7 +125,7 @@ def clone(self): return _NullSynOut() -class TwoEndConn(projections.SynConn): +class TwoEndConn(SynConn): """Base class to model synaptic connections. Parameters diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py index ee1fb2b8f..a070a295a 100644 --- a/brainpy/_src/dynsys.py +++ b/brainpy/_src/dynsys.py @@ -91,7 +91,8 @@ def __init__( # Attribute for "SupportInputProj" # each instance of "SupportInputProj" should have a "cur_inputs" attribute - self.cur_inputs = bm.node_dict() + self.current_inputs = bm.node_dict() + self.delta_inputs = bm.node_dict() # the before- / after-updates used for computing # added after the version of 2.4.3 diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py index 6ac7f3a3d..323fe872c 100644 --- a/brainpy/_src/mixin.py +++ b/brainpy/_src/mixin.py @@ -21,7 +21,6 @@ DynamicalSystem = None delay_identifier, init_delay_by_return = None, None - __all__ = [ 'MixIn', 'ParamDesc', @@ -53,7 +52,6 @@ def _get_dynsys(): return DynamicalSystem - class MixIn(object): """Base MixIn object. @@ -378,55 +376,119 @@ def get_delay_var(self, name): class SupportInputProj(MixIn): """The :py:class:`~.MixIn` that receives the input projections. - Note that the subclass should define a ``cur_inputs`` attribute. + Note that the subclass should define a ``cur_inputs`` attribute. Otherwise, + the input function utilities cannot be used. 
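+
+  Two categories of input functions are supported: ``'current'`` inputs, which are
+  stored in ``self.current_inputs`` and summed by ``self.sum_current_inputs()``, and
+  ``'delta'`` inputs, which are stored in ``self.delta_inputs`` and summed by
+  ``self.sum_delta_inputs()``. Input functions are registered with
+  ``self.add_inp_fun(key, fun, label=..., category=...)``.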
""" - cur_inputs: bm.node_dict + current_inputs: bm.node_dict + delta_inputs: bm.node_dict - def add_inp_fun(self, key: Any, fun: Callable): + def add_inp_fun(self, key: str, fun: Callable, label: Optional[str] = None, category: str = 'current'): """Add an input function. Args: - key: The dict key. - fun: The function to generate inputs. + key: str. The dict key. + fun: Callable. The function to generate inputs. + label: str. The input label. + category: str. The input category, should be ``current`` (the current) or + ``delta`` (the delta synapse, indicating the delta function). """ if not callable(fun): raise TypeError('Must be a function.') - if key in self.cur_inputs: - raise ValueError(f'Key "{key}" has been defined and used.') - self.cur_inputs[key] = fun - def get_inp_fun(self, key): + key = self._input_label_repr(key, label) + if category == 'current': + if key in self.current_inputs: + raise ValueError(f'Key "{key}" has been defined and used.') + self.current_inputs[key] = fun + elif category == 'delta': + if key in self.delta_inputs: + raise ValueError(f'Key "{key}" has been defined and used.') + self.delta_inputs[key] = fun + else: + raise NotImplementedError(f'Unknown category: {category}. Only support "current" and "delta".') + + def get_inp_fun(self, key: str): """Get the input function. Args: - key: The key. + key: str. The key. Returns: The input function which generates currents. """ - return self.cur_inputs.get(key) + if key in self.current_inputs: + return self.current_inputs[key] + elif key in self.delta_inputs: + return self.delta_inputs[key] + else: + raise ValueError(f'Unknown key: {key}') + + def sum_current_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs): + """Summarize all current inputs by the defined input functions ``.current_inputs``. + + Args: + *args: The arguments for input functions. + init: The initial input data. + label: str. The input label. + **kwargs: The arguments for input functions. + + Returns: + The total currents. + """ + if label is None: + for key, out in self.current_inputs.items(): + init = init + out(*args, **kwargs) + else: + label_repr = self._input_label_start(label) + for key, out in self.current_inputs.items(): + if key.startswith(label_repr): + init = init + out(*args, **kwargs) + return init - def sum_inputs(self, *args, init=0., label=None, **kwargs): - """Summarize all inputs by the defined input functions ``.cur_inputs``. + def sum_delta_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs): + """Summarize all delta inputs by the defined input functions ``.delta_inputs``. Args: *args: The arguments for input functions. init: The initial input data. + label: str. The input label. **kwargs: The arguments for input functions. Returns: The total currents. """ if label is None: - for key, out in self.cur_inputs.items(): + for key, out in self.delta_inputs.items(): init = init + out(*args, **kwargs) else: - for key, out in self.cur_inputs.items(): - if key.startswith(label + ' // '): + label_repr = self._input_label_start(label) + for key, out in self.delta_inputs.items(): + if key.startswith(label_repr): init = init + out(*args, **kwargs) return init + @classmethod + def _input_label_start(cls, label: str): + # unify the input label repr. + return f'{label} // ' + + @classmethod + def _input_label_repr(cls, name: str, label: Optional[str] = None): + # unify the input label repr. 
+ return name if label is None else (cls._input_label_start(label) + str(name)) + + # deprecated # + # ---------- # + + @property + def cur_inputs(self): + return self.current_inputs + + def sum_inputs(self, *args, **kwargs): + warnings.warn('Please use ".sum_current_inputs()" instead. ".sum_inputs()" will be removed.', UserWarning) + return self.sum_current_inputs(*args, **kwargs) + class SupportReturnInfo(MixIn): """``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`.""" diff --git a/brainpy/dyn/projections.py b/brainpy/dyn/projections.py index b2f4c5304..23e1a7485 100644 --- a/brainpy/dyn/projections.py +++ b/brainpy/dyn/projections.py @@ -1,24 +1,24 @@ - -from brainpy._src.dyn.projections.aligns import ( - VanillaProj, - ProjAlignPostMg1, - ProjAlignPostMg2, - ProjAlignPost1, - ProjAlignPost2, - ProjAlignPreMg1, - ProjAlignPreMg2, - ProjAlignPre1, - ProjAlignPre2, +from brainpy._src.dyn.projections.vanilla import VanillaProj +from brainpy._src.dyn.projections.delta import ( + HalfProjDelta, + FullProjDelta, +) +from brainpy._src.dyn.projections.align_post import ( + HalfProjAlignPostMg, + FullProjAlignPostMg, + HalfProjAlignPost, + FullProjAlignPost, +) +from brainpy._src.dyn.projections.align_pre import ( + FullProjAlignPreSDMg, + FullProjAlignPreDSMg, + FullProjAlignPreSD, + FullProjAlignPreDS, ) - from brainpy._src.dyn.projections.conn import ( SynConn as SynConn, ) - -from brainpy._src.dyn.projections.others import ( - PoissonInput as PoissonInput, -) - from brainpy._src.dyn.projections.inputs import ( InputVar, + PoissonInput, ) diff --git a/brainpy/dyn/synapses.py b/brainpy/dyn/synapses.py index 68be31944..9a097be1a 100644 --- a/brainpy/dyn/synapses.py +++ b/brainpy/dyn/synapses.py @@ -1,6 +1,5 @@ from brainpy._src.dyn.synapses.abstract_models import ( - Delta, Expon, Alpha, DualExpon, diff --git a/docs/apis/brainpy.dyn.projections.rst b/docs/apis/brainpy.dyn.projections.rst index c1f8c1070..5549e6394 100644 --- a/docs/apis/brainpy.dyn.projections.rst +++ b/docs/apis/brainpy.dyn.projections.rst @@ -6,27 +6,23 @@ Synaptic Projections -Reduced Projections -------------------- +Projections for Align-Post Reduction +------------------------------------ .. autosummary:: :toctree: generated/ :nosignatures: :template: classtemplate.rst - ProjAlignPostMg1 - ProjAlignPostMg2 - ProjAlignPost1 - ProjAlignPost2 - ProjAlignPreMg1 - ProjAlignPreMg2 - ProjAlignPre1 - ProjAlignPre2 + HalfProjAlignPostMg + FullProjAlignPostMg + HalfProjAlignPost + FullProjAlignPost -Projections ------------ +Projections for Align-Pre Reduction +------------------------------------ .. autosummary:: :toctree: generated/ @@ -34,7 +30,23 @@ Projections :template: classtemplate.rst VanillaProj - SynConn + FullProjAlignPreSDMg + FullProjAlignPreDSMg + FullProjAlignPreSD + FullProjAlignPreDS + + + +Projections for Delta synapses +------------------------------ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + HalfProjDelta + FullProjDelta @@ -46,6 +58,18 @@ Inputs :nosignatures: :template: classtemplate.rst - PoissonInput InputVar + + + +Others +------ + +.. 
autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + SynConn + diff --git a/docs/apis/brainpy.dyn.synapses.rst b/docs/apis/brainpy.dyn.synapses.rst index ea4313c69..bea61ab87 100644 --- a/docs/apis/brainpy.dyn.synapses.rst +++ b/docs/apis/brainpy.dyn.synapses.rst @@ -42,7 +42,6 @@ Phenomenological synapse models :nosignatures: :template: classtemplate.rst - Delta Expon Alpha DualExpon diff --git a/docs/apis/losses.rst b/docs/apis/losses.rst index 8f50c487f..4f4a3d167 100644 --- a/docs/apis/losses.rst +++ b/docs/apis/losses.rst @@ -33,6 +33,14 @@ Comparison log_cosh_loss ctc_loss_with_forward_probs ctc_loss + multi_margin_loss + + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + CrossEntropyLoss NLLLoss L1Loss diff --git a/docs/tutorial_FAQs/brainpy_ecosystem.ipynb b/docs/tutorial_FAQs/brainpy_ecosystem.ipynb index ed88c9596..4b28375b5 100644 --- a/docs/tutorial_FAQs/brainpy_ecosystem.ipynb +++ b/docs/tutorial_FAQs/brainpy_ecosystem.ipynb @@ -51,6 +51,35 @@ "\n", "[brainpy-largescale](https://github.com/NH-NCL/brainpy-largescale) provides one solution for large-scale modeling. It enables multi-device running for BrainPy models.\n" ] + }, + { + "cell_type": "markdown", + "source": [ + "## 《神经计算建模实战》\n", + "\n", + "[《神经计算建模实战》 (Neural Modeling in Action)](https://github.com/c-xy17/NeuralModeling) is a book for brain dynamics modeling based on BrainPy. It introduces the basic concepts and methods of brain dynamics modeling, and provides comprehensive examples for brain dynamics modeling with BrainPy. \n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## 神经计算建模与编程培训班\n", + "\n", + "There is a series of training courses for brain dynamics modeling based on BrainPy. \n", + "\n", + "- [第一届神经计算建模与编程培训班 (First Training Course on Neural Modeling and Programming)](https://github.com/brainpy/1st-neural-modeling-and-programming-course) \n", + "\n", + "- [第二届神经计算建模与编程培训班 (Second Training Course on Neural Modeling and Programming)](https://github.com/brainpy/2nd-neural-modeling-and-programming-course)\n", + "\n", + "This course is based on the textbook [《神经计算建模实战》 (Neural Modeling in Action)](https://github.com/c-xy17/NeuralModeling), supplemented by BrainPy, and based on the theory of \"theory+practice\" combination of teaching and learning. 
Through this course, students will master the basic concepts, methods and techniques of neural computation modelling, as well as how to use Python programming language to achieve convenient modelling and efficient simulation of neural systems, laying a solid foundation for future research in the field of neural computation or in the field of brain-like intelligence.\n", + "\n" + ], + "metadata": { + "collapsed": false + } } ], "metadata": { diff --git a/examples/dynamics_simulation/COBA.py b/examples/dynamics_simulation/COBA.py index af7511e19..60b325657 100644 --- a/examples/dynamics_simulation/COBA.py +++ b/examples/dynamics_simulation/COBA.py @@ -13,7 +13,7 @@ def __init__(self, num_exc, num_inh, inp=20.): self.E = bp.dyn.LifRefLTC(num_exc, **neu_pars) self.I = bp.dyn.LifRefLTC(num_inh, **neu_pars) - self.E2I = bp.dyn.ProjAlignPreMg1( + self.E2I = bp.dyn.FullProjAlignPreSDMg( pre=self.E, syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), delay=None, @@ -21,7 +21,7 @@ def __init__(self, num_exc, num_inh, inp=20.): out=bp.dyn.COBA(E=0.), post=self.I, ) - self.E2E = bp.dyn.ProjAlignPreMg1( + self.E2E = bp.dyn.FullProjAlignPreSDMg( pre=self.E, syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), delay=None, @@ -29,7 +29,7 @@ def __init__(self, num_exc, num_inh, inp=20.): out=bp.dyn.COBA(E=0.), post=self.E, ) - self.I2E = bp.dyn.ProjAlignPreMg1( + self.I2E = bp.dyn.FullProjAlignPreSDMg( pre=self.I, syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), delay=None, @@ -37,7 +37,7 @@ def __init__(self, num_exc, num_inh, inp=20.): out=bp.dyn.COBA(E=-80.), post=self.E, ) - self.I2I = bp.dyn.ProjAlignPreMg1( + self.I2I = bp.dyn.FullProjAlignPreSDMg( pre=self.I, syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), delay=0., @@ -67,7 +67,7 @@ def __init__(self, num_exc, num_inh, inp=20., ltc=True): self.E = bp.dyn.LifRef(num_exc, **neu_pars) self.I = bp.dyn.LifRef(num_inh, **neu_pars) - self.E2E = bp.dyn.ProjAlignPostMg2( + self.E2E = bp.dyn.FullProjAlignPostMg( pre=self.E, delay=None, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=self.E.num, post=self.E.num), 0.6), @@ -75,7 +75,7 @@ def __init__(self, num_exc, num_inh, inp=20., ltc=True): out=bp.dyn.COBA.desc(E=0.), post=self.E, ) - self.E2I = bp.dyn.ProjAlignPostMg2( + self.E2I = bp.dyn.FullProjAlignPostMg( pre=self.E, delay=None, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=self.E.num, post=self.I.num), 0.6), @@ -83,7 +83,7 @@ def __init__(self, num_exc, num_inh, inp=20., ltc=True): out=bp.dyn.COBA.desc(E=0.), post=self.I, ) - self.I2E = bp.dyn.ProjAlignPostMg2( + self.I2E = bp.dyn.FullProjAlignPostMg( pre=self.I, delay=None, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=self.I.num, post=self.E.num), 6.7), @@ -91,7 +91,7 @@ def __init__(self, num_exc, num_inh, inp=20., ltc=True): out=bp.dyn.COBA.desc(E=-80.), post=self.E, ) - self.I2I = bp.dyn.ProjAlignPostMg2( + self.I2I = bp.dyn.FullProjAlignPostMg( pre=self.I, delay=None, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=self.I.num, post=self.I.num), 6.7), diff --git a/examples/dynamics_simulation/COBA_parallel.py b/examples/dynamics_simulation/COBA_parallel.py index 45cf81953..954b01734 100644 --- a/examples/dynamics_simulation/COBA_parallel.py +++ b/examples/dynamics_simulation/COBA_parallel.py @@ -11,7 +11,7 @@ class ExpJIT(bp.Projection): def __init__(self, pre_num, post, prob, g_max, tau=5., E=0.): super().__init__() - self.proj = bp.dyn.ProjAlignPostMg1( + self.proj = bp.dyn.HalfProjAlignPostMg( comm=bp.dnn.EventJitFPHomoLinear(pre_num, post.num, prob=prob, weight=g_max), 
syn=bp.dyn.Expon.desc(size=post.num, tau=tau, sharding=[bm.sharding.NEU_AXIS]), out=bp.dyn.COBA.desc(E=E), @@ -40,7 +40,7 @@ def update(self, input): class ExpMasked(bp.Projection): def __init__(self, pre_num, post, prob, g_max, tau=5., E=0.): super().__init__() - self.proj = bp.dyn.ProjAlignPostMg1( + self.proj = bp.dyn.HalfProjAlignPostMg( comm=bp.dnn.MaskedLinear(bp.conn.FixedProb(prob, pre=pre_num, post=post.num), weight=g_max, sharding=[None, bm.sharding.NEU_AXIS]), syn=bp.dyn.Expon.desc(size=post.num, tau=tau, sharding=[bm.sharding.NEU_AXIS]), @@ -111,7 +111,7 @@ def _f(self, indices, indptr, x): class ExpMasked2(bp.Projection): def __init__(self, pre_num, post, prob, g_max, tau=5., E=0.): super().__init__() - self.proj = bp.dyn.ProjAlignPostMg1( + self.proj = bp.dyn.HalfProjAlignPostMg( comm=PCSR(bp.conn.FixedProb(prob, pre=pre_num, post=post.num), weight=g_max, num_shard=4), syn=bp.dyn.Expon.desc(size=post.num, tau=tau, sharding=[bm.sharding.NEU_AXIS]), out=bp.dyn.COBA.desc(E=E), diff --git a/examples/dynamics_simulation/decision_making_network.py b/examples/dynamics_simulation/decision_making_network.py index 5351680e6..334f99712 100644 --- a/examples/dynamics_simulation/decision_making_network.py +++ b/examples/dynamics_simulation/decision_making_network.py @@ -18,7 +18,7 @@ def __init__(self, pre, post, conn, delay, g_max, tau, E): raise ValueError syn = bp.dyn.Expon.desc(post.num, tau=tau) out = bp.dyn.COBA.desc(E=E) - self.proj = bp.dyn.ProjAlignPostMg2( + self.proj = bp.dyn.FullProjAlignPostMg( pre=pre, delay=delay, comm=comm, syn=syn, out=out, post=post ) @@ -35,7 +35,7 @@ def __init__(self, pre, post, conn, delay, g_max): raise ValueError syn = bp.dyn.NMDA.desc(pre.num, a=0.5, tau_decay=100., tau_rise=2.) out = bp.dyn.MgBlock(E=0., cc_Mg=1.0) - self.proj = bp.dyn.ProjAlignPreMg2( + self.proj = bp.dyn.FullProjAlignPreDSMg( pre=pre, delay=delay, syn=syn, comm=comm, out=out, post=post ) diff --git a/examples/dynamics_simulation/ei_nets.py b/examples/dynamics_simulation/ei_nets.py index 2243a9ca1..f98527458 100644 --- a/examples/dynamics_simulation/ei_nets.py +++ b/examples/dynamics_simulation/ei_nets.py @@ -9,14 +9,14 @@ def __init__(self): self.N = bp.dyn.LifRefLTC(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.E = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), - syn=bp.dyn.Expon(size=4000, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.N) - self.I = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), - syn=bp.dyn.Expon(size=4000, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.N) + self.E = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), + syn=bp.dyn.Expon(size=4000, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), + syn=bp.dyn.Expon(size=4000, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.N) def update(self, input): spk = self.delay.at('I') @@ -40,30 +40,30 @@ def __init__(self): V_initializer=bp.init.Normal(-55., 2.)) self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPost2(pre=self.E, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - syn=bp.dyn.Expon(size=ne, tau=5.), - 
out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPost2(pre=self.E, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - syn=bp.dyn.Expon(size=ni, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPost2(pre=self.I, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - syn=bp.dyn.Expon(size=ne, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPost2(pre=self.I, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - syn=bp.dyn.Expon(size=ni, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.I) + self.E2E = bp.dyn.FullProjAlignPost(pre=self.E, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + syn=bp.dyn.Expon(size=ne, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPost(pre=self.E, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + syn=bp.dyn.Expon(size=ni, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPost(pre=self.I, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + syn=bp.dyn.Expon(size=ne, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPost(pre=self.I, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + syn=bp.dyn.Expon(size=ni, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.I) def update(self, inp): self.E2E() @@ -118,30 +118,30 @@ def __init__(self): V_initializer=bp.init.Normal(-55., 2.)) self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPreMg1(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPreMg1(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPreMg1(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPreMg1(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) + self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) def update(self, inp): self.E2E() @@ -167,30 +167,30 @@ def __init__(self): V_initializer=bp.init.Normal(-55., 2.)) self.I = bp.dyn.LifRefLTC(ni, 
V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPreMg2(pre=self.E, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPreMg2(pre=self.E, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPreMg2(pre=self.I, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPreMg2(pre=self.I, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) + self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) def update(self, inp): self.E2E()