Loading...

基本概念

1. 激活值的表示方法

spikingjelly.activation_based 使用取值仅为0或1的张量表示脉冲,例如:

1
2
3
4
5
6
7
import torch

v = torch.rand([8])
v_th = 0.5
spike = (v >= v_th).to(v)
print('spike =', spike)
# spike = tensor([0., 0., 0., 1., 1., 0., 1., 0.])

输出结果:

1
2
3
v = tensor([0.8156, 0.7492, 0.0531, 0.7591, 0.4431, 0.7992, 0.8907, 0.2421])

spike = tensor([1., 1., 0., 1., 0., 1., 1., 0.])

2. 数据格式

spikingjelly.activation_based 中,数据有两种格式,分别为:

  • 表示单个时刻的数据,其 shape = [N, *],其中 N 是batch维度,* 表示任意额外的维度

  • 表示多个时刻的数据,其 shape = [T, N, *],其中 T 是数据的时间维度, N 是batch维度,* 表示任意额外的维度


3. 步进模式

spikingjelly.activation_based 中的模块,具有两种传播模式,分别是

  • 单步模式(single-step):数据使用 shape = [N, *] 的格式
  • 多步模式(multi-step):数据使用 shape = [T, N, *] 的格式

模块在初始化时可以指定其使用的步进模式 step_mode,也可以在构建后直接进行修改:

1
2
3
4
5
6
7
import torch
from spikingjelly.activation_based import neuron

net = neuron.IFNode(step_mode='m')
# 'm' is the multi-step mode
net.step_mode = 's'
# 's' is the single-step mode

如果我们想给单步模式的模块输入 shape = [T, N, *] 的序列数据,通常需要手动做一个时间上的循环,将数据拆成 Tshape = [N, *] 的数据并逐步输入进去。让我们新建一层IF神经元,设置为单步模式,将数据逐步输入并得到输出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
from spikingjelly.activation_based import neuron

net_s = neuron.IFNode(step_mode='s')
T = 3 #时间步
N = 1 #batchsize
C = 3 #通道数
H = 8 #数据的高
W = 8 #数据的宽
x_seq = torch.rand([T, N, C, H, W])
y_seq = []
for t in range(T):
x = x_seq[t] # x.shape = [N, C, H, W]
y = net_s(x) # y.shape = [N, C, H, W]
y_seq.append(y.unsqueeze(0)) #将y增加一个维度 添加到列表y_seq中

y_seq = torch.cat(y_seq) #将所有输出张量沿着新的时间维度拼起来,形成新的张量
# y_seq.shape = [T, N, C, H, W]

输出结果:

x序列初始化:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
x_seq= tensor([[[[[2.5490e-01, 1.2639e-01, 5.7598e-01, 9.8435e-01, 2.7988e-01,
8.4105e-01, 5.0002e-01, 7.9076e-01],
[7.1258e-01, 4.4736e-01, 7.8368e-01, 8.3618e-01, 4.0139e-02,
9.7412e-02, 6.7998e-01, 7.1286e-01],
[1.8884e-01, 2.3037e-01, 2.9191e-01, 1.7953e-01, 9.2891e-01,
4.4400e-01, 6.5340e-01, 1.5981e-01],
[2.4112e-01, 4.1668e-01, 6.0014e-01, 4.9383e-01, 3.4121e-01,
2.1231e-01, 7.0532e-01, 6.7256e-01],
[1.9463e-02, 4.5558e-01, 7.2580e-01, 7.8367e-01, 6.1989e-01,
1.6519e-02, 9.6447e-01, 2.3580e-01],
[6.2598e-01, 2.1307e-01, 6.3583e-01, 6.9059e-01, 1.4211e-01,
9.7064e-02, 3.7473e-01, 5.1632e-01],
[6.9992e-01, 7.6289e-01, 3.5203e-01, 4.8954e-01, 8.7898e-01,
7.8275e-03, 5.2165e-01, 7.9674e-01],
[6.0423e-01, 5.1216e-01, 2.0406e-01, 2.4626e-01, 8.3361e-01,
3.1688e-01, 6.5840e-01, 1.3557e-01]],

[[4.7398e-01, 3.7301e-01, 1.1466e-01, 1.8241e-01, 4.3652e-01,
1.6438e-01, 1.9071e-01, 8.9192e-01],
[6.8429e-01, 8.9815e-01, 1.0108e-01, 8.0698e-02, 8.9986e-01,
2.6142e-01, 3.6577e-01, 8.3392e-01],
[5.6779e-01, 1.4391e-01, 3.8700e-01, 5.3148e-01, 8.7907e-01,
3.8305e-02, 2.6470e-01, 4.1808e-01],
[7.0598e-01, 5.7686e-01, 6.4152e-01, 6.2110e-01, 2.7309e-01,
6.1881e-01, 9.8665e-01, 7.9703e-01],
[3.3107e-01, 2.7862e-01, 9.9426e-01, 4.6849e-01, 9.5992e-01,
8.4558e-01, 7.9661e-02, 4.1134e-01],
[5.4624e-02, 9.5138e-01, 1.4058e-01, 9.6501e-01, 3.5145e-01,
2.8402e-01, 3.8676e-01, 6.5260e-01],
[5.2004e-01, 4.7728e-01, 5.5676e-01, 3.6984e-01, 9.6766e-01,
9.6658e-01, 3.1954e-01, 4.4453e-01],
[5.3666e-02, 2.9791e-01, 8.0187e-01, 5.4459e-01, 4.6331e-01,
6.9315e-01, 9.1670e-01, 4.9692e-01]],

[[4.2555e-01, 2.8505e-01, 3.2860e-01, 9.9233e-01, 9.3850e-01,
8.5688e-01, 1.5506e-01, 2.5449e-01],
[9.6989e-01, 5.2753e-01, 7.7811e-01, 5.6960e-01, 3.8716e-01,
7.5363e-01, 9.6678e-01, 3.5747e-01],
[9.7789e-01, 3.2894e-01, 6.8379e-01, 6.4656e-01, 1.5041e-01,
1.1086e-01, 4.0547e-01, 1.8388e-01],
[1.6842e-01, 8.9246e-01, 3.5892e-01, 6.2443e-01, 5.6638e-03,
9.1472e-01, 9.5463e-01, 2.0740e-01],
[5.1649e-01, 3.5757e-01, 1.4834e-01, 5.3815e-01, 4.6948e-01,
8.8748e-01, 6.6702e-01, 4.2158e-01],
[1.2418e-02, 4.4521e-01, 8.2854e-01, 1.9373e-01, 6.4802e-01,
1.6356e-01, 4.1023e-01, 4.1840e-02],
[1.7540e-01, 8.9243e-02, 5.9020e-01, 9.4828e-01, 4.3418e-01,
8.9228e-01, 8.6908e-01, 6.0948e-01],
[3.2250e-01, 1.2573e-01, 4.3231e-01, 8.6035e-01, 2.8213e-01,
1.6248e-04, 9.3926e-01, 6.9541e-02]]]],



[[[[2.3089e-01, 9.8811e-01, 1.8840e-01, 2.8811e-01, 8.0925e-02,
4.9717e-01, 4.4312e-01, 4.5832e-01],
[1.0187e-01, 1.3159e-01, 4.3373e-01, 9.1466e-01, 4.2633e-01,
6.3525e-01, 8.4113e-02, 9.2720e-01],
[2.5519e-01, 4.4282e-01, 5.0446e-01, 8.4111e-01, 9.7644e-01,
2.7621e-01, 7.2177e-02, 9.2591e-01],
[2.4659e-01, 1.5083e-01, 6.7537e-01, 2.8056e-01, 7.1648e-01,
3.3301e-01, 5.1138e-01, 9.6654e-01],
[2.9088e-01, 1.9395e-01, 7.9314e-01, 9.3603e-01, 6.4237e-01,
9.2068e-01, 8.0904e-01, 5.4206e-01],
[9.1367e-02, 4.5597e-01, 1.8322e-01, 1.3655e-02, 2.0238e-01,
3.5409e-01, 5.3555e-01, 6.0905e-01],
[1.6518e-01, 8.0658e-01, 3.7144e-01, 9.4313e-01, 9.2144e-01,
6.8941e-01, 5.7647e-01, 9.3552e-01],
[2.6781e-01, 4.3228e-01, 3.4587e-01, 8.9714e-01, 7.8391e-01,
7.1665e-01, 1.5365e-01, 6.4215e-01]],

[[8.9499e-01, 5.7280e-01, 3.0730e-01, 1.2089e-01, 7.3408e-02,
1.5600e-01, 7.2978e-01, 4.1706e-01],
[7.0067e-01, 8.3507e-02, 3.5878e-01, 4.5111e-01, 5.2309e-01,
1.0923e-01, 6.2161e-01, 8.7058e-01],
[3.4376e-01, 5.7006e-01, 7.8923e-01, 9.6056e-01, 3.4098e-01,
7.4725e-01, 9.9692e-01, 7.1960e-01],
[7.7939e-02, 2.1983e-01, 5.1086e-01, 7.4400e-01, 8.4883e-01,
6.2231e-02, 7.6414e-01, 1.5084e-01],
[2.7195e-01, 7.1122e-01, 3.5306e-01, 2.7154e-01, 9.4240e-01,
3.4124e-01, 9.9192e-01, 8.1886e-01],
[8.2985e-01, 5.0439e-01, 8.2339e-01, 5.8634e-01, 6.2499e-01,
1.7665e-02, 5.8843e-01, 3.1627e-01],
[2.5274e-01, 9.9553e-01, 2.0932e-01, 9.2049e-01, 8.8624e-03,
5.4530e-01, 8.1870e-01, 4.3063e-01],
[3.0285e-01, 7.4531e-01, 4.7332e-01, 5.4647e-01, 9.7685e-01,
1.0625e-01, 4.3261e-01, 3.7397e-01]],

[[8.4365e-02, 5.5545e-01, 2.8297e-01, 4.0169e-01, 7.2488e-01,
4.6406e-01, 6.1969e-01, 6.0868e-01],
[8.8171e-01, 4.7924e-01, 8.2435e-01, 2.6328e-01, 8.0640e-01,
1.2693e-01, 3.8701e-01, 9.9115e-01],
[2.6400e-01, 1.8561e-01, 5.1274e-01, 7.3849e-01, 3.6732e-01,
2.7895e-02, 6.9665e-01, 9.0970e-01],
[6.7955e-01, 3.0955e-01, 3.2706e-01, 7.1760e-01, 7.4908e-01,
9.5461e-01, 6.3022e-01, 7.1336e-01],
[6.3269e-01, 8.7859e-01, 5.2982e-01, 9.9928e-01, 5.9146e-01,
6.0319e-01, 5.5390e-01, 1.2463e-01],
[9.3675e-01, 6.5130e-01, 2.5262e-01, 5.8463e-01, 8.3454e-01,
8.4987e-01, 7.4374e-01, 6.9787e-01],
[4.3745e-01, 1.3517e-01, 3.3122e-01, 9.1682e-01, 7.2772e-01,
6.8150e-01, 5.7321e-01, 9.3979e-01],
[5.7491e-01, 7.5739e-01, 8.3784e-02, 3.8622e-01, 5.8774e-02,
2.2148e-01, 6.8164e-01, 2.5192e-03]]]],



[[[[8.8257e-01, 4.1188e-01, 6.5191e-01, 2.5938e-01, 5.5664e-01,
8.5302e-01, 5.0620e-01, 3.8792e-01],
[1.6840e-01, 8.3896e-01, 5.1250e-01, 2.5363e-01, 9.8131e-01,
7.7844e-01, 6.9882e-01, 2.7488e-01],
[5.9363e-01, 7.3822e-01, 8.7725e-01, 8.7122e-02, 9.7863e-01,
2.2889e-02, 3.8829e-01, 4.9083e-01],
[6.3988e-01, 3.6111e-01, 8.6678e-01, 7.2978e-01, 2.7875e-01,
9.0873e-02, 1.2682e-01, 8.5834e-02],
[7.2737e-01, 5.7738e-01, 4.3050e-01, 3.7719e-01, 7.7123e-01,
8.9869e-01, 5.3084e-01, 6.6507e-01],
[6.5190e-01, 3.9860e-01, 9.9952e-01, 1.2379e-01, 8.0746e-01,
9.7127e-01, 7.3734e-01, 7.4749e-01],
[2.1485e-01, 4.7792e-01, 2.8937e-01, 4.1115e-01, 6.8940e-01,
9.0786e-01, 8.9880e-01, 1.4463e-01],
[7.0096e-01, 1.3410e-01, 1.6139e-01, 6.7555e-01, 3.6055e-01,
1.1819e-01, 4.5473e-02, 5.0582e-01]],

[[9.2175e-01, 1.3608e-01, 3.0082e-01, 6.9805e-01, 6.6508e-01,
5.9953e-01, 2.3041e-01, 8.6712e-01],
[1.2560e-01, 7.5882e-01, 8.8983e-01, 2.9503e-01, 5.5485e-01,
8.0917e-01, 5.3546e-01, 3.6151e-01],
[5.7187e-01, 7.3159e-02, 9.6407e-02, 9.0816e-01, 4.0066e-02,
7.9157e-01, 9.9889e-01, 7.8920e-01],
[3.7230e-01, 3.7325e-01, 4.4339e-01, 3.4131e-01, 7.2465e-01,
4.7837e-01, 4.3739e-01, 8.3011e-01],
[2.0016e-01, 6.5061e-01, 5.3490e-01, 4.0561e-01, 9.8350e-01,
5.5320e-01, 7.0122e-01, 5.4667e-02],
[1.5879e-01, 5.4470e-01, 7.9612e-01, 9.4982e-01, 7.3154e-01,
6.7719e-01, 4.0903e-01, 1.1639e-01],
[1.4119e-01, 6.6653e-01, 2.5942e-01, 7.6776e-01, 1.3594e-01,
9.8087e-01, 7.9487e-01, 2.7664e-02],
[5.3392e-01, 4.4390e-01, 1.0553e-01, 9.9423e-01, 9.7157e-01,
1.5699e-01, 9.2323e-01, 8.4864e-01]],

[[9.0617e-01, 1.5988e-01, 2.6798e-01, 7.0112e-01, 1.1410e-03,
8.4951e-02, 8.7687e-01, 3.3408e-01],
[7.5975e-01, 4.8089e-01, 3.0892e-01, 6.6032e-01, 8.7860e-01,
3.7284e-01, 1.6264e-01, 5.8246e-01],
[3.9431e-02, 1.2114e-01, 7.6150e-01, 5.9015e-01, 2.1748e-01,
7.6507e-01, 6.4522e-02, 1.1229e-01],
[7.7141e-01, 4.4919e-01, 5.0839e-01, 2.5215e-01, 5.7316e-01,
5.6246e-01, 7.1742e-01, 8.5685e-01],
[3.9496e-01, 1.4702e-01, 9.2753e-01, 9.3801e-01, 7.0536e-01,
4.4737e-01, 6.0602e-01, 4.4963e-01],
[2.0537e-01, 4.4529e-01, 9.8825e-01, 6.4068e-01, 6.6816e-01,
8.5496e-01, 1.4431e-01, 9.1136e-01],
[8.2264e-01, 9.6913e-01, 1.0884e-01, 8.5493e-01, 1.8374e-01,
8.6501e-01, 1.6153e-01, 3.0293e-01],
[6.5948e-01, 6.0356e-01, 1.6664e-01, 3.3722e-01, 9.7794e-01,
1.3833e-02, 1.6070e-01, 3.0181e-01]]]]])

时间步t=1的情况:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
x= tensor([[[[2.5490e-01, 1.2639e-01, 5.7598e-01, 9.8435e-01, 2.7988e-01,
8.4105e-01, 5.0002e-01, 7.9076e-01],
[7.1258e-01, 4.4736e-01, 7.8368e-01, 8.3618e-01, 4.0139e-02,
9.7412e-02, 6.7998e-01, 7.1286e-01],
[1.8884e-01, 2.3037e-01, 2.9191e-01, 1.7953e-01, 9.2891e-01,
4.4400e-01, 6.5340e-01, 1.5981e-01],
[2.4112e-01, 4.1668e-01, 6.0014e-01, 4.9383e-01, 3.4121e-01,
2.1231e-01, 7.0532e-01, 6.7256e-01],
[1.9463e-02, 4.5558e-01, 7.2580e-01, 7.8367e-01, 6.1989e-01,
1.6519e-02, 9.6447e-01, 2.3580e-01],
[6.2598e-01, 2.1307e-01, 6.3583e-01, 6.9059e-01, 1.4211e-01,
9.7064e-02, 3.7473e-01, 5.1632e-01],
[6.9992e-01, 7.6289e-01, 3.5203e-01, 4.8954e-01, 8.7898e-01,
7.8275e-03, 5.2165e-01, 7.9674e-01],
[6.0423e-01, 5.1216e-01, 2.0406e-01, 2.4626e-01, 8.3361e-01,
3.1688e-01, 6.5840e-01, 1.3557e-01]],

[[4.7398e-01, 3.7301e-01, 1.1466e-01, 1.8241e-01, 4.3652e-01,
1.6438e-01, 1.9071e-01, 8.9192e-01],
[6.8429e-01, 8.9815e-01, 1.0108e-01, 8.0698e-02, 8.9986e-01,
2.6142e-01, 3.6577e-01, 8.3392e-01],
[5.6779e-01, 1.4391e-01, 3.8700e-01, 5.3148e-01, 8.7907e-01,
3.8305e-02, 2.6470e-01, 4.1808e-01],
[7.0598e-01, 5.7686e-01, 6.4152e-01, 6.2110e-01, 2.7309e-01,
6.1881e-01, 9.8665e-01, 7.9703e-01],
[3.3107e-01, 2.7862e-01, 9.9426e-01, 4.6849e-01, 9.5992e-01,
8.4558e-01, 7.9661e-02, 4.1134e-01],
[5.4624e-02, 9.5138e-01, 1.4058e-01, 9.6501e-01, 3.5145e-01,
2.8402e-01, 3.8676e-01, 6.5260e-01],
[5.2004e-01, 4.7728e-01, 5.5676e-01, 3.6984e-01, 9.6766e-01,
9.6658e-01, 3.1954e-01, 4.4453e-01],
[5.3666e-02, 2.9791e-01, 8.0187e-01, 5.4459e-01, 4.6331e-01,
6.9315e-01, 9.1670e-01, 4.9692e-01]],

[[4.2555e-01, 2.8505e-01, 3.2860e-01, 9.9233e-01, 9.3850e-01,
8.5688e-01, 1.5506e-01, 2.5449e-01],
[9.6989e-01, 5.2753e-01, 7.7811e-01, 5.6960e-01, 3.8716e-01,
7.5363e-01, 9.6678e-01, 3.5747e-01],
[9.7789e-01, 3.2894e-01, 6.8379e-01, 6.4656e-01, 1.5041e-01,
1.1086e-01, 4.0547e-01, 1.8388e-01],
[1.6842e-01, 8.9246e-01, 3.5892e-01, 6.2443e-01, 5.6638e-03,
9.1472e-01, 9.5463e-01, 2.0740e-01],
[5.1649e-01, 3.5757e-01, 1.4834e-01, 5.3815e-01, 4.6948e-01,
8.8748e-01, 6.6702e-01, 4.2158e-01],
[1.2418e-02, 4.4521e-01, 8.2854e-01, 1.9373e-01, 6.4802e-01,
1.6356e-01, 4.1023e-01, 4.1840e-02],
[1.7540e-01, 8.9243e-02, 5.9020e-01, 9.4828e-01, 4.3418e-01,
8.9228e-01, 8.6908e-01, 6.0948e-01],
[3.2250e-01, 1.2573e-01, 4.3231e-01, 8.6035e-01, 2.8213e-01,
1.6248e-04, 9.3926e-01, 6.9541e-02]]]])
y= tensor([[[[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]],

[[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]],

[[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]]]])

时间步t=2:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
x= tensor([[[[0.2309, 0.9881, 0.1884, 0.2881, 0.0809, 0.4972, 0.4431, 0.4583],
[0.1019, 0.1316, 0.4337, 0.9147, 0.4263, 0.6352, 0.0841, 0.9272],
[0.2552, 0.4428, 0.5045, 0.8411, 0.9764, 0.2762, 0.0722, 0.9259],
[0.2466, 0.1508, 0.6754, 0.2806, 0.7165, 0.3330, 0.5114, 0.9665],
[0.2909, 0.1940, 0.7931, 0.9360, 0.6424, 0.9207, 0.8090, 0.5421],
[0.0914, 0.4560, 0.1832, 0.0137, 0.2024, 0.3541, 0.5356, 0.6091],
[0.1652, 0.8066, 0.3714, 0.9431, 0.9214, 0.6894, 0.5765, 0.9355],
[0.2678, 0.4323, 0.3459, 0.8971, 0.7839, 0.7167, 0.1536, 0.6421]],

[[0.8950, 0.5728, 0.3073, 0.1209, 0.0734, 0.1560, 0.7298, 0.4171],
[0.7007, 0.0835, 0.3588, 0.4511, 0.5231, 0.1092, 0.6216, 0.8706],
[0.3438, 0.5701, 0.7892, 0.9606, 0.3410, 0.7472, 0.9969, 0.7196],
[0.0779, 0.2198, 0.5109, 0.7440, 0.8488, 0.0622, 0.7641, 0.1508],
[0.2719, 0.7112, 0.3531, 0.2715, 0.9424, 0.3412, 0.9919, 0.8189],
[0.8299, 0.5044, 0.8234, 0.5863, 0.6250, 0.0177, 0.5884, 0.3163],
[0.2527, 0.9955, 0.2093, 0.9205, 0.0089, 0.5453, 0.8187, 0.4306],
[0.3028, 0.7453, 0.4733, 0.5465, 0.9768, 0.1062, 0.4326, 0.3740]],

[[0.0844, 0.5555, 0.2830, 0.4017, 0.7249, 0.4641, 0.6197, 0.6087],
[0.8817, 0.4792, 0.8243, 0.2633, 0.8064, 0.1269, 0.3870, 0.9911],
[0.2640, 0.1856, 0.5127, 0.7385, 0.3673, 0.0279, 0.6967, 0.9097],
[0.6796, 0.3095, 0.3271, 0.7176, 0.7491, 0.9546, 0.6302, 0.7134],
[0.6327, 0.8786, 0.5298, 0.9993, 0.5915, 0.6032, 0.5539, 0.1246],
[0.9367, 0.6513, 0.2526, 0.5846, 0.8345, 0.8499, 0.7437, 0.6979],
[0.4374, 0.1352, 0.3312, 0.9168, 0.7277, 0.6815, 0.5732, 0.9398],
[0.5749, 0.7574, 0.0838, 0.3862, 0.0588, 0.2215, 0.6816, 0.0025]]]])
y= tensor([[[[0., 1., 0., 1., 0., 1., 0., 1.],
[0., 0., 1., 1., 0., 0., 0., 1.],
[0., 0., 0., 1., 1., 0., 0., 1.],
[0., 0., 1., 0., 1., 0., 1., 1.],
[0., 0., 1., 1., 1., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 1.],
[0., 1., 0., 1., 1., 0., 1., 1.],
[0., 0., 0., 1., 1., 1., 0., 0.]],

[[1., 0., 0., 0., 0., 0., 0., 1.],
[1., 0., 0., 0., 1., 0., 0., 1.],
[0., 0., 1., 1., 1., 0., 1., 1.],
[0., 0., 1., 1., 1., 0., 1., 0.],
[0., 0., 1., 0., 1., 1., 1., 1.],
[0., 1., 0., 1., 0., 0., 0., 0.],
[0., 1., 0., 1., 0., 1., 1., 0.],
[0., 1., 1., 1., 1., 0., 1., 0.]],

[[0., 0., 0., 1., 1., 1., 0., 0.],
[1., 1., 1., 0., 1., 0., 1., 1.],
[1., 0., 1., 1., 0., 0., 1., 1.],
[0., 1., 0., 1., 0., 1., 1., 0.],
[1., 1., 0., 1., 1., 1., 1., 0.],
[0., 1., 1., 0., 1., 1., 1., 0.],
[0., 0., 0., 1., 1., 1., 1., 1.],
[0., 0., 0., 1., 0., 0., 1., 0.]]]])

时间步t=3:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
x= tensor([[[[0.8826, 0.4119, 0.6519, 0.2594, 0.5566, 0.8530, 0.5062, 0.3879],
[0.1684, 0.8390, 0.5125, 0.2536, 0.9813, 0.7784, 0.6988, 0.2749],
[0.5936, 0.7382, 0.8773, 0.0871, 0.9786, 0.0229, 0.3883, 0.4908],
[0.6399, 0.3611, 0.8668, 0.7298, 0.2788, 0.0909, 0.1268, 0.0858],
[0.7274, 0.5774, 0.4305, 0.3772, 0.7712, 0.8987, 0.5308, 0.6651],
[0.6519, 0.3986, 0.9995, 0.1238, 0.8075, 0.9713, 0.7373, 0.7475],
[0.2149, 0.4779, 0.2894, 0.4112, 0.6894, 0.9079, 0.8988, 0.1446],
[0.7010, 0.1341, 0.1614, 0.6756, 0.3606, 0.1182, 0.0455, 0.5058]],

[[0.9218, 0.1361, 0.3008, 0.6980, 0.6651, 0.5995, 0.2304, 0.8671],
[0.1256, 0.7588, 0.8898, 0.2950, 0.5548, 0.8092, 0.5355, 0.3615],
[0.5719, 0.0732, 0.0964, 0.9082, 0.0401, 0.7916, 0.9989, 0.7892],
[0.3723, 0.3733, 0.4434, 0.3413, 0.7247, 0.4784, 0.4374, 0.8301],
[0.2002, 0.6506, 0.5349, 0.4056, 0.9835, 0.5532, 0.7012, 0.0547],
[0.1588, 0.5447, 0.7961, 0.9498, 0.7315, 0.6772, 0.4090, 0.1164],
[0.1412, 0.6665, 0.2594, 0.7678, 0.1359, 0.9809, 0.7949, 0.0277],
[0.5339, 0.4439, 0.1055, 0.9942, 0.9716, 0.1570, 0.9232, 0.8486]],

[[0.9062, 0.1599, 0.2680, 0.7011, 0.0011, 0.0850, 0.8769, 0.3341],
[0.7598, 0.4809, 0.3089, 0.6603, 0.8786, 0.3728, 0.1626, 0.5825],
[0.0394, 0.1211, 0.7615, 0.5901, 0.2175, 0.7651, 0.0645, 0.1123],
[0.7714, 0.4492, 0.5084, 0.2521, 0.5732, 0.5625, 0.7174, 0.8568],
[0.3950, 0.1470, 0.9275, 0.9380, 0.7054, 0.4474, 0.6060, 0.4496],
[0.2054, 0.4453, 0.9883, 0.6407, 0.6682, 0.8550, 0.1443, 0.9114],
[0.8226, 0.9691, 0.1088, 0.8549, 0.1837, 0.8650, 0.1615, 0.3029],
[0.6595, 0.6036, 0.1666, 0.3372, 0.9779, 0.0138, 0.1607, 0.3018]]]])
y= tensor([[[[1., 0., 1., 0., 0., 0., 1., 0.],
[0., 1., 0., 0., 1., 1., 1., 0.],
[1., 1., 1., 0., 0., 0., 1., 0.],
[1., 0., 0., 1., 0., 0., 0., 0.],
[1., 1., 0., 0., 0., 1., 0., 1.],
[1., 1., 1., 0., 1., 1., 1., 0.],
[1., 0., 1., 0., 0., 1., 0., 0.],
[1., 1., 0., 0., 0., 0., 0., 1.]],

[[0., 1., 0., 1., 1., 0., 1., 0.],
[0., 1., 1., 0., 0., 1., 1., 0.],
[1., 0., 0., 0., 0., 1., 0., 0.],
[1., 1., 0., 0., 0., 1., 0., 1.],
[0., 1., 0., 1., 0., 0., 0., 0.],
[1., 0., 1., 0., 1., 0., 1., 1.],
[0., 0., 1., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1.]],

[[1., 1., 0., 0., 0., 0., 1., 1.],
[0., 0., 0., 1., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[1., 0., 1., 0., 1., 0., 0., 1.],
[0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 1., 0., 0., 0., 1.],
[1., 1., 1., 0., 0., 0., 0., 0.],
[1., 1., 0., 0., 1., 0., 0., 0.]]]])

最终输出的y序列:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
y_seq= tensor([[[[[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]],

[[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]],

[[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]]]],



[[[[0., 1., 0., 1., 0., 1., 0., 1.],
[0., 0., 1., 1., 0., 0., 0., 1.],
[0., 0., 0., 1., 1., 0., 0., 1.],
[0., 0., 1., 0., 1., 0., 1., 1.],
[0., 0., 1., 1., 1., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 1.],
[0., 1., 0., 1., 1., 0., 1., 1.],
[0., 0., 0., 1., 1., 1., 0., 0.]],

[[1., 0., 0., 0., 0., 0., 0., 1.],
[1., 0., 0., 0., 1., 0., 0., 1.],
[0., 0., 1., 1., 1., 0., 1., 1.],
[0., 0., 1., 1., 1., 0., 1., 0.],
[0., 0., 1., 0., 1., 1., 1., 1.],
[0., 1., 0., 1., 0., 0., 0., 0.],
[0., 1., 0., 1., 0., 1., 1., 0.],
[0., 1., 1., 1., 1., 0., 1., 0.]],

[[0., 0., 0., 1., 1., 1., 0., 0.],
[1., 1., 1., 0., 1., 0., 1., 1.],
[1., 0., 1., 1., 0., 0., 1., 1.],
[0., 1., 0., 1., 0., 1., 1., 0.],
[1., 1., 0., 1., 1., 1., 1., 0.],
[0., 1., 1., 0., 1., 1., 1., 0.],
[0., 0., 0., 1., 1., 1., 1., 1.],
[0., 0., 0., 1., 0., 0., 1., 0.]]]],



[[[[1., 0., 1., 0., 0., 0., 1., 0.],
[0., 1., 0., 0., 1., 1., 1., 0.],
[1., 1., 1., 0., 0., 0., 1., 0.],
[1., 0., 0., 1., 0., 0., 0., 0.],
[1., 1., 0., 0., 0., 1., 0., 1.],
[1., 1., 1., 0., 1., 1., 1., 0.],
[1., 0., 1., 0., 0., 1., 0., 0.],
[1., 1., 0., 0., 0., 0., 0., 1.]],

[[0., 1., 0., 1., 1., 0., 1., 0.],
[0., 1., 1., 0., 0., 1., 1., 0.],
[1., 0., 0., 0., 0., 1., 0., 0.],
[1., 1., 0., 0., 0., 1., 0., 1.],
[0., 1., 0., 1., 0., 0., 0., 0.],
[1., 0., 1., 0., 1., 0., 1., 1.],
[0., 0., 1., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1.]],

[[1., 1., 0., 0., 0., 0., 1., 1.],
[0., 0., 0., 1., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[1., 0., 1., 0., 1., 0., 0., 1.],
[0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 1., 0., 0., 0., 1.],
[1., 1., 1., 0., 0., 0., 0., 0.],
[1., 1., 0., 0., 1., 0., 0., 0.]]]]])

multi_step_forward 提供了将 shape = [T, N, *] 的序列数据输入到单步模块进行逐步的前向传播的封装,即将上面的函数进行了封装,使用起来更加方便:

1
2
3
4
5
6
7
8
9
10
11
import torch
from spikingjelly.activation_based import neuron, functional
net_s = neuron.IFNode(step_mode='s')
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
y_seq = functional.multi_step_forward(x_seq, net_s)
# y_seq.shape = [T, N, C, H, W]

但是,直接将模块设置成多步模块,其实更为便捷:

1
2
3
4
5
6
7
8
9
10
11
12
import torch
from spikingjelly.activation_based import neuron

net_m = neuron.IFNode(step_mode='m')
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
y_seq = net_m(x_seq)
# y_seq.shape = [T, N, C, H, W]


测试这两个模式输出的结果是不是相同,下面是代码细节:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#单步模式
net_s = neuron.IFNode(step_mode='s')
T = 3
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
y_seq = []
#print("x_seq=",x_seq)
for t in range(T):
x = x_seq[t] # x.shape = [N, C, H, W]
y = net_s(x) # y.shape = [N, C, H, W]
#print("x=",x)
#print("y=",y)
#print("T=",T)
y_seq.append(y.unsqueeze(0))

#y_seq是有三个元素[1,1,3,8,8]的列表
y_seq_s = torch.cat(y_seq) #[3,1,3,8,8]
print("y_seq(单步)=",y_seq_s)
# y_seq.shape = [T, N, C, H, W]



# 与多步模式做对比
net_m = neuron.IFNode(step_mode='m')
# T = 4
# N = 1
# C = 3
# H = 8
# W = 8
#x_seq = torch.rand([T, N, C, H, W])
y_seq_m = net_m(x_seq)
print("y_seq(多步)=",y_seq)
is_equal = torch.equal(y_seq_s, y_seq_m)
print("输出是否相同:", is_equal)

输出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
y_seq(单步)= tensor([[[[[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]],

[[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]],

[[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]]]],



[[[[1., 1., 0., 0., 0., 1., 1., 1.],
[0., 1., 0., 0., 0., 1., 1., 1.],
[1., 0., 1., 1., 1., 0., 0., 1.],
[1., 0., 0., 0., 0., 1., 0., 1.],
[0., 0., 0., 0., 1., 1., 0., 1.],
[1., 0., 0., 0., 0., 0., 1., 0.],
[1., 0., 1., 0., 1., 0., 1., 0.],
[1., 0., 0., 1., 0., 1., 1., 1.]],

[[0., 0., 0., 0., 0., 1., 0., 1.],
[1., 0., 0., 0., 1., 0., 1., 0.],
[0., 0., 1., 0., 1., 0., 0., 1.],
[0., 1., 1., 1., 1., 0., 1., 1.],
[1., 0., 1., 0., 0., 1., 1., 0.],
[1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 0., 1., 0., 0., 0., 0.],
[1., 0., 1., 1., 0., 0., 0., 0.]],

[[0., 1., 1., 0., 1., 0., 0., 1.],
[1., 1., 1., 1., 0., 0., 1., 0.],
[0., 0., 1., 1., 0., 0., 1., 0.],
[1., 1., 1., 1., 0., 0., 1., 1.],
[1., 0., 0., 1., 1., 0., 1., 0.],
[0., 0., 0., 0., 0., 1., 1., 1.],
[1., 1., 1., 1., 0., 1., 1., 0.],
[0., 1., 0., 0., 1., 1., 0., 0.]]]],



[[[[0., 0., 0., 1., 1., 0., 0., 0.],
[1., 0., 1., 1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 1., 1., 0.],
[0., 1., 0., 1., 1., 0., 1., 0.],
[1., 1., 0., 1., 0., 0., 1., 0.],
[0., 1., 1., 1., 1., 1., 0., 0.],
[0., 1., 0., 1., 0., 1., 0., 1.],
[0., 1., 0., 0., 1., 0., 0., 0.]],

[[0., 1., 1., 1., 1., 0., 0., 0.],
[0., 1., 1., 1., 0., 0., 0., 1.],
[1., 1., 0., 1., 0., 1., 1., 0.],
[0., 0., 0., 0., 0., 1., 0., 0.],
[0., 1., 0., 1., 0., 0., 0., 1.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 1., 1., 0., 1.],
[0., 0., 0., 0., 1., 1., 0., 1.]],

[[1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0.],
[0., 1., 0., 0., 1., 1., 0., 0.],
[0., 0., 0., 0., 1., 1., 0., 0.],
[0., 0., 1., 0., 0., 1., 0., 1.],
[1., 1., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 1.],
[1., 0., 1., 0., 0., 0., 0., 1.]]]]])
y_seq(多步)= [tensor([[[[[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]],

[[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]],

[[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]]]]]), tensor([[[[[1., 1., 0., 0., 0., 1., 1., 1.],
[0., 1., 0., 0., 0., 1., 1., 1.],
[1., 0., 1., 1., 1., 0., 0., 1.],
[1., 0., 0., 0., 0., 1., 0., 1.],
[0., 0., 0., 0., 1., 1., 0., 1.],
[1., 0., 0., 0., 0., 0., 1., 0.],
[1., 0., 1., 0., 1., 0., 1., 0.],
[1., 0., 0., 1., 0., 1., 1., 1.]],

[[0., 0., 0., 0., 0., 1., 0., 1.],
[1., 0., 0., 0., 1., 0., 1., 0.],
[0., 0., 1., 0., 1., 0., 0., 1.],
[0., 1., 1., 1., 1., 0., 1., 1.],
[1., 0., 1., 0., 0., 1., 1., 0.],
[1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 0., 1., 0., 0., 0., 0.],
[1., 0., 1., 1., 0., 0., 0., 0.]],

[[0., 1., 1., 0., 1., 0., 0., 1.],
[1., 1., 1., 1., 0., 0., 1., 0.],
[0., 0., 1., 1., 0., 0., 1., 0.],
[1., 1., 1., 1., 0., 0., 1., 1.],
[1., 0., 0., 1., 1., 0., 1., 0.],
[0., 0., 0., 0., 0., 1., 1., 1.],
[1., 1., 1., 1., 0., 1., 1., 0.],
[0., 1., 0., 0., 1., 1., 0., 0.]]]]]), tensor([[[[[0., 0., 0., 1., 1., 0., 0., 0.],
[1., 0., 1., 1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 1., 1., 0.],
[0., 1., 0., 1., 1., 0., 1., 0.],
[1., 1., 0., 1., 0., 0., 1., 0.],
[0., 1., 1., 1., 1., 1., 0., 0.],
[0., 1., 0., 1., 0., 1., 0., 1.],
[0., 1., 0., 0., 1., 0., 0., 0.]],

[[0., 1., 1., 1., 1., 0., 0., 0.],
[0., 1., 1., 1., 0., 0., 0., 1.],
[1., 1., 0., 1., 0., 1., 1., 0.],
[0., 0., 0., 0., 0., 1., 0., 0.],
[0., 1., 0., 1., 0., 0., 0., 1.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 1., 1., 0., 1.],
[0., 0., 0., 0., 1., 1., 0., 1.]],

[[1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0.],
[0., 1., 0., 0., 1., 1., 0., 0.],
[0., 0., 0., 0., 1., 1., 0., 0.],
[0., 0., 1., 0., 0., 1., 0., 1.],
[1., 1., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 1.],
[1., 0., 1., 0., 0., 0., 0., 1.]]]]])]
输出是否相同: True


4. 状态保存与重置

SNN中的神经元等模块,与RNN类似,带有隐藏状态,其输出y[t]不仅仅与当前时刻的输入x[t]有关,还与上一个时末的状态h[t-1]有关,即y[t]=f(x[t],h[t-1])。

PyTorch的设计为RNN将状态也一并输出,可以参考torch.nn.RNN的API文档。而在spikingjelly.activation_based中,状态会被保存在模块内部。例如,我们新建一层IF神经元,设置为单步模式,查看给与输入前的默认电压,和给与输入后的电压:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from spikingjelly.activation_based import neuron

net_s = neuron.IFNode(step_mode='s')
x = torch.rand([4])
print(net_s)
print(f'the initial v={net_s.v}')
y = net_s(x)
print(f'x={x}')
print(f'y={y}')
print(f'v={net_s.v}')

# outputs are:

'''
IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
the initial v=0.0
x=tensor([0.5543, 0.0350, 0.2171, 0.6740])
y=tensor([0., 0., 0., 0.])
v=tensor([0.5543, 0.0350, 0.2171, 0.6740])
'''

在初始化后,IF神经元层的v会被设置为0,首次给与输入后v会自动广播到与输入相同的shape

若我们给与一个新的输入,则应该先清除神经元之前的状态,让其恢复到初始化状态,可以通过调用模块的self.reset()函数实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
from spikingjelly.activation_based import neuron

net_s = neuron.IFNode(step_mode='s')
x = torch.rand([4])
print(f'check point 0: v={net_s.v}')
y = net_s(x)
print(f'check point 1: v={net_s.v}')
net_s.reset()
print(f'check point 2: v={net_s.v}')
x = torch.rand([8])
y = net_s(x)
print(f'check point 3: v={net_s.v}')

# outputs are:

'''
check point 0: v=0.0
check point 1: v=tensor([0.9775, 0.6598, 0.7577, 0.2952])
check point 2: v=0.0
check point 3: v=tensor([0.8728, 0.9031, 0.2278, 0.5089, 0.1059, 0.0479, 0.5008, 0.8530])
'''

方便起见,还可以通过调用spikingjelly.activation_based.functional.reset_net将整个网络中的所有有状态模块进行重置。

若网络使用了有状态的模块,在训练和推理时,务必在处理完毕一个batch的数据后进行重置:

1
2
3
4
5
6
7
8
9
10
11
12
from spikingjelly.activation_based import functional
# ...
for x, label in tqdm(train_data_loader):
# ...
optimizer.zero_grad()
y = net(x)
loss = criterion(y, label)
loss.backward()
optimizer.step()

functional.reset_net(net) #重置网络状态
# Never forget to reset the network!

如果忘了重置,在推理时可能输出错误的结果,而在训练时则会直接报错:

1
2
3
RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). 
Saved intermediate values of the graph are freed when you call .backward() or autograd.grad().
Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.

5. 传播模式

若一个网络全部由单步模块构成,则整个网络的计算顺序是按照逐步传播(step-by-step)的模式进行,例如:

1
2
3
4
5
6
for t in range(T):
x = x_seq[t]
y = net(x) #将输入数据传递给网络
y_seq_step_by_step.append(y.unsqueeze(0))

y_seq_step_by_step = torch.cat(y_seq_step_by_step, 0)

如果网络全部由多步模块构成,则整个网络的计算顺序是按照逐层传播(layer-by-layer)的模式进行,例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, functional, layer
T = 4 #时间步
N = 2 # batchsize
C = 8 # 输入特征数量
x_seq = torch.rand([T, N, C]) * 64.

net = nn.Sequential(
layer.Linear(C, 4),
neuron.IFNode(),
layer.Linear(4, 2),
neuron.IFNode()
)

functional.set_step_mode(net, step_mode='m')
with torch.no_grad():
y_seq_layer_by_layer = x_seq
for i in range(net.__len__()):
y_seq_layer_by_layer = net[i](y_seq_layer_by_layer)

在绝大多数情况下我们不需要显式的实现 for i in range(net.__len__()) 这样的循环,因为 torch.nn.Sequential 已经帮我们实现过了,因此实际上我们可以这样做:

1
y_seq_layer_by_layer = net(x_seq)

逐步传播和逐层传播,实际上只是计算顺序不同,它们的计算结果是完全相同的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, functional, layer
T = 4
N = 2
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W]) * 64.

net = nn.Sequential(
layer.Conv2d(3, 8, kernel_size=3, padding=1, stride=1, bias=False),
neuron.IFNode(),
layer.MaxPool2d(2, 2),
neuron.IFNode(),
layer.Flatten(start_dim=1),
layer.Linear(8 * H // 2 * W // 2, 10),
neuron.IFNode(),
)

print(f'net={net}') #打印网络信息

with torch.no_grad():
y_seq_step_by_step = []
for t in range(T):
x = x_seq[t]
y = net(x)
y_seq_step_by_step.append(y.unsqueeze(0))

y_seq_step_by_step = torch.cat(y_seq_step_by_step, 0)
# we can also use `y_seq_step_by_step = functional.multi_step_forward(x_seq, net)` to get the same results

print(f'y_seq_step_by_step=\n{y_seq_step_by_step}')

functional.reset_net(net)
functional.set_step_mode(net, step_mode='m')
y_seq_layer_by_layer = net(x_seq)

max_error = (y_seq_layer_by_layer - y_seq_step_by_step).abs().max()
print(f'max_error={max_error}') #表示两次输出是没有差别的

上面这段代码的输出为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
net=Sequential(
(0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=s)
(3): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
(4): Flatten(start_dim=1, end_dim=-1, step_mode=s)
(5): Linear(in_features=128, out_features=10, bias=True)
(6): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
)
y_seq_step_by_step=
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

[[0., 1., 0., 0., 0., 0., 0., 1., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]],

[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 1., 0., 0., 1., 0., 0., 0.]],

[[0., 1., 0., 0., 0., 0., 1., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]]])
max_error=0.0

下面的图片展示了逐步传播构建计算图的顺序:


下面的图片展示了逐层传播构建计算图的顺序:


SNN的计算图有2个维度,分别是时间步数和网络深度,网络的传播实际上就是生成完整计算图的过程,正如上面的2张图片所示。实际上,逐步传播是深度优先遍历,而逐层传播是广度优先遍历。

尽管两者区别仅在于计算顺序,但计算速度和内存消耗上会略有区别。

  • 在使用梯度替代法训练时,通常推荐使用逐层传播。在正确构建网络的情况下,逐层传播的并行度更大,速度更快

  • 在内存受限时使用逐步传播,例如ANN2SNN任务中需要用到非常大的T。因为在逐层传播模式下,对无状态的层而言,真正的batch size是TN而不是N(参见下一个教程),当T太大时内存消耗极大