
TODO

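  • 1. Why does PSN have no reset?
  • 2. What are the drawbacks of PSN?
  • 3. Get the code running.
  • 4. Also investigate why t b d does not converge; try modifications and fix the bugs in the code.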

PSN code

```python
import math

import torch
import torch.nn as nn
# Assumes SpikingJelly's activation_based package, which provides `base` and `surrogate`
from spikingjelly.activation_based import base, surrogate


class PSN(nn.Module, base.MultiStepModule):
    def __init__(self, T: int, surrogate_function: surrogate.SurrogateFunctionBase = surrogate.ATan()):
        """
        :param T: the number of time-steps
        :type T: int
        :param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
        :type surrogate_function: Callable

        The Parallel Spiking Neuron proposed in `Parallel Spiking Neurons with High Efficiency and Long-term Dependencies Learning Ability <https://arxiv.org/abs/2304.12760>`_. The neuronal dynamics are defined as

        .. math::

            H &= WX, ~~~~~~~~~~~~~~~W \\in \\mathbb{R}^{T \\times T}, X \\in \\mathbb{R}^{T \\times N} \\label{eq psn neuronal charge}\\\\
            S &= \\Theta(H - B), ~~~~~B \\in \\mathbb{R}^{T}, S\\in \\{0, 1\\}^{T \\times N}

        where :math:`W` is the learnable weight matrix, and :math:`B` is the learnable threshold.

        .. admonition:: Note
            :class: note

            The PSN only supports the multi-step mode.
        """
        super().__init__()
        self.T = T
        self.surrogate_function = surrogate_function
        # One weight per (output step, input step) pair; one bias per time step,
        # shaped [T, 1] so it broadcasts over all neurons
        weight = torch.zeros([T, T])
        bias = torch.zeros([T, 1])

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias)

        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        # bias stores -B, so -1 corresponds to an initial threshold B = 1
        nn.init.constant_(self.bias, -1.)

    def forward(self, x_seq: torch.Tensor):
        # x_seq.shape = [T, N, *]; flatten(1) gives [T, N * ...]
        # h_seq = bias + weight @ x, i.e. WX - B
        h_seq = torch.addmm(self.bias, self.weight, x_seq.flatten(1))
        # Heaviside step in the forward pass, surrogate gradient in the backward pass
        spike_seq = self.surrogate_function(h_seq)
        return spike_seq.view(x_seq.shape)

    def extra_repr(self):
        return super().extra_repr() + f', T={self.T}'
```
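To check the reading of `forward` against the docstring equations without installing PyTorch or SpikingJelly, here is a minimal NumPy sketch. Assumptions: NumPy stands in for torch; `psn_forward` and `atan_surrogate_grad` are hypothetical helper names, not SpikingJelly APIs; the ATan derivative alpha / 2 / (1 + (pi/2 · alpha · x)^2) with alpha = 2.0 is my understanding of SpikingJelly's `surrogate.ATan` and is worth confirming against its source. The sketch shows that `bias` plays the role of -B, that the whole sequence comes out of one matrix product with no state carried from step to step (so there is nothing for a reset to act on), and that a dense W lets H at step 0 depend on the input at the last step.

```python
import numpy as np

# Sketch: NumPy stand-in for the PyTorch PSN module; helper names are hypothetical.


def psn_forward(x_seq: np.ndarray, weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Mirror of PSN.forward; x_seq has shape [T, N, *]."""
    T = x_seq.shape[0]
    x_flat = x_seq.reshape(T, -1)       # like x_seq.flatten(1): [T, N * ...]
    h_seq = bias + weight @ x_flat      # like torch.addmm(bias, weight, x_flat)
    # Heaviside step (spike when h >= 0); no loop over t and no reset term
    spike_seq = (h_seq >= 0).astype(x_seq.dtype)
    return spike_seq.reshape(x_seq.shape)


def atan_surrogate_grad(x: np.ndarray, alpha: float = 2.0) -> np.ndarray:
    """Derivative of arctan(pi/2 * alpha * x) / pi + 1/2, used instead of Heaviside's derivative.

    alpha = 2.0 is assumed to match SpikingJelly's ATan default.
    """
    return alpha / 2 / (1 + (np.pi / 2 * alpha * x) ** 2)


rng = np.random.default_rng(0)
T, N = 4, 3
x_seq = rng.standard_normal((T, N, 2))  # [T, N, *] with * = (2,)
weight = rng.standard_normal((T, T))    # dense T x T weight, as in the module
B = np.ones((T, 1))                     # threshold B from the docstring
bias = -B                               # the module stores -B (bias initialised to -1)

s = psn_forward(x_seq, weight, bias)

# Same spikes as Theta(WX - B), written directly from the docstring equations
h = weight @ x_seq.reshape(T, -1)
assert np.array_equal(s.reshape(T, -1), (h - B >= 0).astype(x_seq.dtype))

# With a dense W, changing the input at step T-1 shifts H at step 0 by W[0, T-1]
x_late = x_seq.copy()
x_late[-1] += 1.0
h_late = weight @ x_late.reshape(T, -1)
assert np.allclose(h_late[0] - h[0], weight[0, -1])
```

Because S never feeds back into H, all T steps are computed at once; the trade-off visible in the code is that W is a fixed T × T matrix, so a given module is tied to one sequence length T.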