Skip to content

Commit 0330504

Browse files
author
Alexander Ororbia
committed
refactored/ported RAFCell to v3
1 parent 92b6940 commit 0330504

File tree

1 file changed

+50
-45
lines changed

1 file changed

+50
-45
lines changed

ngclearn/components/neurons/spiking/RAFCell.py

Lines changed: 50 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22
from jax import numpy as jnp, random, jit, nn
33
from functools import partial
44
from ngclearn.utils import tensorstats
5-
from ngcsimlib.deprecators import deprecate_args
5+
from ngcsimlib import deprecate_args
66
from ngcsimlib.logger import info, warn
77
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
88
step_euler, step_rk2
99

10-
from ngcsimlib.compilers.process import transition
11-
#from ngcsimlib.component import Component
10+
from ngcsimlib.parser import compilable
1211
from ngcsimlib.compartment import Compartment
1312

13+
########################################################################################################################
14+
## RAF dynamics (multi-dimensional ODEs)
1415
@jit
1516
def _dfv_internal(j, v, w, tau_m, omega, b): ## "voltage" dynamics
1617
# dy/dt = omega x + b y
@@ -34,6 +35,7 @@ def _dfw(t, w, params): ## angular driver dynamics wrapper
3435
j, v, tau_w, omega, b = params
3536
dv_dt = _dfw_internal(j, v, w, tau_w, omega, b)
3637
return dv_dt
38+
########################################################################################################################
3739

3840
class RAFCell(JaxComponent):
3941
"""
@@ -60,8 +62,7 @@ class RAFCell(JaxComponent):
6062
| tols - time-of-last-spike
6163
6264
| References:
63-
| Izhikevich, Eugene M. "Resonate-and-fire neurons." Neural networks
64-
| 14.6-7 (2001): 883-894.
65+
| Izhikevich, Eugene M. "Resonate-and-fire neurons." Neural networks 14.6-7 (2001): 883-894.
6566
6667
Args:
6768
name: the string name of this cell
@@ -77,7 +78,7 @@ class RAFCell(JaxComponent):
7778
7879
omega: angular frequency (Default: 10)
7980
80-
b: oscillation dampening factor (Default: -1)
81+
dampen_factor: oscillation dampening factor (Default: -1) ("b" in Izhikevich 2001)
8182
8283
v_reset: reset condition for membrane potential (Default: 1 mV)
8384
@@ -98,10 +99,10 @@ class RAFCell(JaxComponent):
9899
at an increase in computational cost (and simulation time)
99100
"""
100101

101-
@deprecate_args(resist_m="resist_v", tau_m="tau_v")
102+
@deprecate_args(resist_m="resist_v", tau_m="tau_v", b="dampen_factor")
102103
def __init__(
103-
self, name, n_units, tau_v=1., tau_w=1., thr=1., omega=10., b=-1., v_reset=0., w_reset=0., v0=0., w0=0.,
104-
resist_v=1., integration_type="euler", batch_size=1, **kwargs
104+
self, name, n_units, tau_v=1., tau_w=1., thr=1., omega=10., dampen_factor=-1., v_reset=0., w_reset=0.,
105+
v0=0., w0=0., resist_v=1., integration_type="euler", batch_size=1, **kwargs
105106
):
106107
#v_rest=-72., v_reset=-75., w_reset=0., thr=5., v0=-70., w0=0., tau_w=400., thr=5., omega=10., b=-1.
107108
super().__init__(name, **kwargs)
@@ -115,8 +116,8 @@ def __init__(
115116
self.resist_v = resist_v
116117
self.tau_w = tau_w
117118
self.omega = omega ## angular frequency
118-
self.b = b ## dampening factor
119-
## note: the smaller b is, the faster the oscillation dampens to resting state values
119+
self.dampen_factor = dampen_factor ## dampening factor (b)
120+
## Note: the smaller that dampen_factor "b" is, the faster the oscillation dampens to resting state values
120121
self.v_reset = v_reset
121122
self.w_reset = w_reset
122123
self.v0 = v0
@@ -137,42 +138,46 @@ def __init__(
137138
restVals, display_name="Time-of-Last-Spike", units="ms"
138139
) ## time-of-last-spike
139140

140-
@transition(output_compartments=["j", "v", "w", "s", "tols"])
141-
@staticmethod
142-
def advance_state(t, dt, tau_v, resist_v, tau_w, thr, omega, b,
143-
v_reset, w_reset, intgFlag, j, v, w, tols):
141+
@compilable
142+
def advance_state(
143+
self, t, dt
144+
):
144145
## continue with centered dynamics
145-
j_ = j * resist_v
146-
if intgFlag == 1: ## RK-2/midpoint
146+
j_ = self.j.get() * self.resist_v
147+
if self.intgFlag == 1: ## RK-2/midpoint
147148
## Note: we integrate ODEs in order: first w, then v
148-
w_params = (j_, v, tau_w, omega, b)
149-
_, _w = step_rk2(0., w, _dfw, dt, w_params)
150-
v_params = (j_, _w, tau_v, omega, b)
151-
_, _v = step_rk2(0., v, _dfv, dt, v_params)
149+
w_params = (j_, self.v.get(), self.tau_w, self.omega, self.dampen_factor)
150+
_, _w = step_rk2(0., self.w.get(), _dfw, dt, w_params)
151+
v_params = (j_, _w, self.tau_v, self.omega, self.dampen_factor)
152+
_, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params)
152153
else: # integType == 0 (default -- Euler)
153154
## Note: we integrate ODEs in order: first w, then v
154-
w_params = (j_, v, tau_w, omega, b)
155-
_, _w = step_euler(0., w, _dfw, dt, w_params)
156-
v_params = (j_, _w, tau_v, omega, b)
157-
_, _v = step_euler(0., v, _dfv, dt, v_params)
158-
s = (_v > thr) * 1. ## emit spikes/pulses
155+
w_params = (j_, self.v.get(), self.tau_w, self.omega, self.dampen_factor)
156+
_, _w = step_euler(0., self.w.get(), _dfw, dt, w_params)
157+
v_params = (j_, _w, self.tau_v, self.omega, self.dampen_factor)
158+
_, _v = step_euler(0., self.v.get(), _dfv, dt, v_params)
159+
160+
s = (_v > self.thr) * 1. ## emit spikes/pulses
159161
## hyperpolarize/reset/snap variables
160-
w = _w * (1. - s) + s * w_reset
161-
v = _v * (1. - s) + s * v_reset
162-
163-
tols = (1. - s) * tols + (s * t) ## update times-of-last-spike(s)
164-
return j, v, w, s, tols
165-
166-
@transition(output_compartments=["j", "v", "w", "s", "tols"])
167-
@staticmethod
168-
def reset(batch_size, n_units, v0, w0):
169-
restVals = jnp.zeros((batch_size, n_units))
170-
j = restVals # None
171-
v = restVals + v0
172-
w = restVals + w0
173-
s = restVals #+ 0
174-
tols = restVals #+ 0
175-
return j, v, w, s, tols
162+
w = _w * (1. - s) + s * self.w_reset
163+
v = _v * (1. - s) + s * self.v_reset
164+
165+
self.tols.set((1. - s) * self.tols.get() + (s * t)) ## update times-of-last-spike(s)
166+
167+
#self.j.set(j_)
168+
self.v.set(v)
169+
self.w.set(w)
170+
self.s.set(s)
171+
172+
@compilable
173+
def reset(self):
174+
restVals = jnp.zeros((self.batch_size, self.n_units))
175+
if not self.j.targeted:
176+
self.j.set(restVals)
177+
self.v.set(restVals + self.v0)
178+
self.w.set(restVals + self.w0)
179+
self.s.set(restVals)
180+
self.tols.set(restVals)
176181

177182
@classmethod
178183
def help(cls): ## component help function
@@ -198,7 +203,7 @@ def help(cls): ## component help function
198203
"tau_w": "Recovery variable time constant",
199204
"v_reset": "Reset membrane potential value",
200205
"w_reset": "Reset angular driver value",
201-
"b": "Exponential dampening factor applied to oscillations",
206+
"dampen_factor": "Exponential dampening factor applied to oscillations",
202207
"omega": "Angular frequency of neuronal progress per second (radians)",
203208
"v0": "Initial condition for membrane potential/voltage",
204209
"w0": "Initial condition for membrane angular driver variable",
@@ -207,8 +212,8 @@ def help(cls): ## component help function
207212
}
208213
info = {cls.__name__: properties,
209214
"compartments": compartment_props,
210-
"dynamics": "tau_v * dv/dt = omega * w + v * b; "
211-
"tau_w * dw/dt = w * b - v * omega + j",
215+
"dynamics": "tau_v * dv/dt = omega * w + v * dampen_factor; "
216+
"tau_w * dw/dt = w * dampen_factor - v * omega + j",
212217
"hyperparameters": hyperparams}
213218
return info
214219

0 commit comments

Comments
 (0)