22from jax import numpy as jnp , random , jit , nn
33from functools import partial
44from ngclearn .utils import tensorstats
5- from ngcsimlib . deprecators import deprecate_args
5+ from ngcsimlib import deprecate_args
66from ngcsimlib .logger import info , warn
77from ngclearn .utils .diffeq .ode_utils import get_integrator_code , \
88 step_euler , step_rk2
99
10- from ngcsimlib .compilers .process import transition
11- #from ngcsimlib.component import Component
10+ from ngcsimlib .parser import compilable
1211from ngcsimlib .compartment import Compartment
1312
13+ ########################################################################################################################
14+ ## RAF dynamics (multi-dimensional ODEs)
1415@jit
1516def _dfv_internal (j , v , w , tau_m , omega , b ): ## "voltage" dynamics
1617 # dy/dt = omega x + b y
@@ -34,6 +35,7 @@ def _dfw(t, w, params): ## angular driver dynamics wrapper
3435 j , v , tau_w , omega , b = params
3536 dv_dt = _dfw_internal (j , v , w , tau_w , omega , b )
3637 return dv_dt
38+ ########################################################################################################################
3739
3840class RAFCell (JaxComponent ):
3941 """
@@ -60,8 +62,7 @@ class RAFCell(JaxComponent):
6062 | tols - time-of-last-spike
6163
6264 | References:
63- | Izhikevich, Eugene M. "Resonate-and-fire neurons." Neural networks
64- | 14.6-7 (2001): 883-894.
65+ | Izhikevich, Eugene M. "Resonate-and-fire neurons." Neural networks 14.6-7 (2001): 883-894.
6566
6667 Args:
6768 name: the string name of this cell
@@ -77,7 +78,7 @@ class RAFCell(JaxComponent):
7778
7879 omega: angular frequency (Default: 10)
7980
80- b : oscillation dampening factor (Default: -1)
81+ dampen_factor : oscillation dampening factor (Default: -1) ("b" in Izhikevich 2001 )
8182
8283 v_reset: reset condition for membrane potential (Default: 1 mV)
8384
@@ -98,10 +99,10 @@ class RAFCell(JaxComponent):
9899 at an increase in computational cost (and simulation time)
99100 """
100101
101- @deprecate_args (resist_m = "resist_v" , tau_m = "tau_v" )
102+ @deprecate_args (resist_m = "resist_v" , tau_m = "tau_v" , b = "dampen_factor" )
102103 def __init__ (
103- self , name , n_units , tau_v = 1. , tau_w = 1. , thr = 1. , omega = 10. , b = - 1. , v_reset = 0. , w_reset = 0. , v0 = 0. , w0 = 0. ,
104- resist_v = 1. , integration_type = "euler" , batch_size = 1 , ** kwargs
104+ self , name , n_units , tau_v = 1. , tau_w = 1. , thr = 1. , omega = 10. , dampen_factor = - 1. , v_reset = 0. , w_reset = 0. ,
105+ v0 = 0. , w0 = 0. , resist_v = 1. , integration_type = "euler" , batch_size = 1 , ** kwargs
105106 ):
106107 #v_rest=-72., v_reset=-75., w_reset=0., thr=5., v0=-70., w0=0., tau_w=400., thr=5., omega=10., b=-1.
107108 super ().__init__ (name , ** kwargs )
@@ -115,8 +116,8 @@ def __init__(
115116 self .resist_v = resist_v
116117 self .tau_w = tau_w
117118 self .omega = omega ## angular frequency
118- self .b = b ## dampening factor
119- ## note : the smaller b is, the faster the oscillation dampens to resting state values
119+ self .dampen_factor = dampen_factor ## dampening factor (b)
120+ ## Note : the smaller that dampen_factor "b" is, the faster the oscillation dampens to resting state values
120121 self .v_reset = v_reset
121122 self .w_reset = w_reset
122123 self .v0 = v0
@@ -137,42 +138,46 @@ def __init__(
137138 restVals , display_name = "Time-of-Last-Spike" , units = "ms"
138139 ) ## time-of-last-spike
139140
140- @transition ( output_compartments = [ "j" , "v" , "w" , "s" , "tols" ])
141- @ staticmethod
142- def advance_state ( t , dt , tau_v , resist_v , tau_w , thr , omega , b ,
143- v_reset , w_reset , intgFlag , j , v , w , tols ):
141+ @compilable
142+ def advance_state (
143+ self , t , dt
144+ ):
144145 ## continue with centered dynamics
145- j_ = j * resist_v
146- if intgFlag == 1 : ## RK-2/midpoint
146+ j_ = self . j . get () * self . resist_v
147+ if self . intgFlag == 1 : ## RK-2/midpoint
147148 ## Note: we integrate ODEs in order: first w, then v
148- w_params = (j_ , v , tau_w , omega , b )
149- _ , _w = step_rk2 (0. , w , _dfw , dt , w_params )
150- v_params = (j_ , _w , tau_v , omega , b )
151- _ , _v = step_rk2 (0. , v , _dfv , dt , v_params )
149+ w_params = (j_ , self . v . get (), self . tau_w , self . omega , self . dampen_factor )
150+ _ , _w = step_rk2 (0. , self . w . get () , _dfw , dt , w_params )
151+ v_params = (j_ , _w , self . tau_v , self . omega , self . dampen_factor )
152+ _ , _v = step_rk2 (0. , self . v . get () , _dfv , dt , v_params )
152153 else : # integType == 0 (default -- Euler)
153154 ## Note: we integrate ODEs in order: first w, then v
154- w_params = (j_ , v , tau_w , omega , b )
155- _ , _w = step_euler (0. , w , _dfw , dt , w_params )
156- v_params = (j_ , _w , tau_v , omega , b )
157- _ , _v = step_euler (0. , v , _dfv , dt , v_params )
158- s = (_v > thr ) * 1. ## emit spikes/pulses
155+ w_params = (j_ , self .v .get (), self .tau_w , self .omega , self .dampen_factor )
156+ _ , _w = step_euler (0. , self .w .get (), _dfw , dt , w_params )
157+ v_params = (j_ , _w , self .tau_v , self .omega , self .dampen_factor )
158+ _ , _v = step_euler (0. , self .v .get (), _dfv , dt , v_params )
159+
160+ s = (_v > self .thr ) * 1. ## emit spikes/pulses
159161 ## hyperpolarize/reset/snap variables
160- w = _w * (1. - s ) + s * w_reset
161- v = _v * (1. - s ) + s * v_reset
162-
163- tols = (1. - s ) * tols + (s * t ) ## update times-of-last-spike(s)
164- return j , v , w , s , tols
165-
166- @transition (output_compartments = ["j" , "v" , "w" , "s" , "tols" ])
167- @staticmethod
168- def reset (batch_size , n_units , v0 , w0 ):
169- restVals = jnp .zeros ((batch_size , n_units ))
170- j = restVals # None
171- v = restVals + v0
172- w = restVals + w0
173- s = restVals #+ 0
174- tols = restVals #+ 0
175- return j , v , w , s , tols
162+ w = _w * (1. - s ) + s * self .w_reset
163+ v = _v * (1. - s ) + s * self .v_reset
164+
165+ self .tols .set ((1. - s ) * self .tols .get () + (s * t )) ## update times-of-last-spike(s)
166+
167+ #self.j.set(j_)
168+ self .v .set (v )
169+ self .w .set (w )
170+ self .s .set (s )
171+
172+ @compilable
173+ def reset (self ):
174+ restVals = jnp .zeros ((self .batch_size , self .n_units ))
175+ if not self .j .targeted :
176+ self .j .set (restVals )
177+ self .v .set (restVals + self .v0 )
178+ self .w .set (restVals + self .w0 )
179+ self .s .set (restVals )
180+ self .tols .set (restVals )
176181
177182 @classmethod
178183 def help (cls ): ## component help function
@@ -198,7 +203,7 @@ def help(cls): ## component help function
198203 "tau_w" : "Recovery variable time constant" ,
199204 "v_reset" : "Reset membrane potential value" ,
200205 "w_reset" : "Reset angular driver value" ,
201- "b " : "Exponential dampening factor applied to oscillations" ,
206+ "dampen_factor " : "Exponential dampening factor applied to oscillations" ,
202207 "omega" : "Angular frequency of neuronal progress per second (radians)" ,
203208 "v0" : "Initial condition for membrane potential/voltage" ,
204209 "w0" : "Initial condition for membrane angular driver variable" ,
@@ -207,8 +212,8 @@ def help(cls): ## component help function
207212 }
208213 info = {cls .__name__ : properties ,
209214 "compartments" : compartment_props ,
210- "dynamics" : "tau_v * dv/dt = omega * w + v * b ; "
211- "tau_w * dw/dt = w * b - v * omega + j" ,
215+ "dynamics" : "tau_v * dv/dt = omega * w + v * dampen_factor ; "
216+ "tau_w * dw/dt = w * dampen_factor - v * omega + j" ,
212217 "hyperparameters" : hyperparams }
213218 return info
214219
0 commit comments