import math

import torch
import torch.nn as nn

from spikingjelly.activation_based import base, surrogate


class PSN(nn.Module, base.MultiStepModule):
    def __init__(self, T: int, surrogate_function: surrogate.SurrogateFunctionBase = surrogate.ATan()):
        """
        :param T: the number of time-steps
        :type T: int
        :param surrogate_function: the function for calculating surrogate gradients of the Heaviside step function in the backward pass
        :type surrogate_function: Callable

        The Parallel Spiking Neuron proposed in `Parallel Spiking Neurons with High Efficiency and Long-term Dependencies Learning Ability <https://arxiv.org/abs/2304.12760>`_. The neuronal dynamics are defined as
        .. math::

            H &= WX, ~~~~~~~~~~~~~~~W \\in \\mathbb{R}^{T \\times T}, X \\in \\mathbb{R}^{T \\times N} \\\\
            S &= \\Theta(H - B), ~~~~~B \\in \\mathbb{R}^{T}, S \\in \\{0, 1\\}^{T \\times N}

        where :math:`W` is the learnable weight matrix, and :math:`B` is the learnable threshold.

        .. admonition:: Note
            :class: note

            The PSN only supports the multi-step mode.
        """
        super().__init__()
        self.T = T
        self.surrogate_function = surrogate_function
        # W is the T x T time-mixing weight; bias stores -B, the negated threshold
        weight = torch.zeros([T, T])
        bias = torch.zeros([T, 1])

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias)

        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        # bias is initialized to -1, i.e., an initial firing threshold of B = 1
        nn.init.constant_(self.bias, -1.)

    def forward(self, x_seq: torch.Tensor):
        # x_seq.shape = [T, N, *]; flatten to [T, N'] so all T steps form one matmul
        h_seq = torch.addmm(self.bias, self.weight, x_seq.flatten(1))  # h_seq = W X + bias = H - B
        spike_seq = self.surrogate_function(h_seq)  # S = Theta(H - B); surrogate gradient in backward
        return spike_seq.view(x_seq.shape)

    def extra_repr(self):
        return super().extra_repr() + f', T={self.T}'
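
A minimal usage sketch follows, assuming the class is exposed as spikingjelly.activation_based.neuron.PSN (its location in the SpikingJelly codebase); T=4 and the input shape are arbitrary values chosen for illustration:

import torch
from spikingjelly.activation_based import neuron

psn = neuron.PSN(T=4)            # T must match the first dimension of the input
x_seq = torch.rand([4, 2, 8])    # input sequence of shape [T, N, *] = [4, 2, 8]
spike_seq = psn(x_seq)           # all 4 time-steps computed by a single addmm
print(spike_seq.shape)           # torch.Size([4, 2, 8]); entries are binary spikes

Because the charge H = WX is one dense matrix product over the time dimension rather than a step-by-step recurrence, the whole sequence is processed in parallel, which is the source of the efficiency gains reported in the paper; the trade-off is that W mixes all time-steps, so the spike at step t can depend on inputs from later steps.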