Containers

SpikingJelly mainly provides the following containers:

  • the function-style multi_step_forward and the module-style MultiStepContainer

  • the function-style seq_to_ann_forward and the module-style SeqToANNContainer

  • StepModeContainer, which wraps single-step modules for single-step/multi-step propagation


1. multi_step_forward can perform multi-step propagation with a single-step module, while MultiStepContainer can wrap a single-step module into a multi-step module.

functional.multi_step_forward: a function that directly takes an input sequence and a neuron node, and processes the whole time sequence internally.

layer.MultiStepContainer: a container that wraps one or more neuron nodes so that they process the whole input time sequence as a single unit.

import torch
from spikingjelly.activation_based import neuron, functional, layer

net_s = neuron.IFNode(step_mode='s')
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
y_seq = functional.multi_step_forward(x_seq, net_s)
# y_seq.shape = [T, N, C, H, W]

net_s.reset()
net_m = layer.MultiStepContainer(net_s)
z_seq = net_m(x_seq)
# z_seq.shape = [T, N, C, H, W]

# z_seq is identical to y_seq
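
Conceptually, multi_step_forward just loops over the time dimension and feeds each time step to the single-step module. Below is a minimal sketch of this idea (an assumed, illustrative implementation, not SpikingJelly's actual source; the real function also accepts a tuple or list of modules):

import torch
import torch.nn as nn

def multi_step_forward_sketch(x_seq: torch.Tensor, module: nn.Module) -> torch.Tensor:
    # x_seq.shape = [T, N, *]; call the single-step module once per time step
    y_seq = []
    for t in range(x_seq.shape[0]):
        y_seq.append(module(x_seq[t]))
    # stack the outputs back along the time dimension: shape = [T, N, *]
    return torch.stack(y_seq)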

For stateless ANN layers such as torch.nn.Conv2d, whose forward expects inputs of shape = [N, *], a multi-step wrapper can be used to apply them in multi-step mode:

import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer

with torch.no_grad():
    T = 4
    N = 1
    C = 3
    H = 8
    W = 8
    x_seq = torch.rand([T, N, C, H, W])

    conv = nn.Conv2d(C, 8, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8)

    y_seq = functional.multi_step_forward(x_seq, (conv, bn))
    # y_seq.shape = [T, N, 8, H, W]

    net = layer.MultiStepContainer(conv, bn)
    z_seq = net(x_seq)
    # z_seq.shape = [T, N, 8, H, W]

    # z_seq is identical to y_seq

However, ANN layers are stateless: there is no dependence on previous time steps, so there is no need to compute them serially in time. They can instead be wrapped with the function-style seq_to_ann_forward or the module-style SeqToANNContainer. seq_to_ann_forward first reshapes data of shape = [T, N, *] to shape = [TN, *], feeds it to the stateless layer, and then reshapes the output back to shape = [T, N, *]. Data at different time steps are computed in parallel, which is faster:

import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer

with torch.no_grad():
    T = 4
    N = 1
    C = 3
    H = 8
    W = 8
    x_seq = torch.rand([T, N, C, H, W])

    conv = nn.Conv2d(C, 8, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8)

    y_seq = functional.multi_step_forward(x_seq, (conv, bn))
    # y_seq.shape = [T, N, 8, H, W]

    net = layer.MultiStepContainer(conv, bn)
    z_seq = net(x_seq)
    # z_seq.shape = [T, N, 8, H, W]

    # z_seq is identical to y_seq

    p_seq = functional.seq_to_ann_forward(x_seq, (conv, bn))
    # p_seq.shape = [T, N, 8, H, W]

    net = layer.SeqToANNContainer(conv, bn)
    q_seq = net(x_seq)
    # q_seq.shape = [T, N, 8, H, W]

    # q_seq is identical to p_seq, and also identical to y_seq and z_seq
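
The reshape trick described above can be sketched as follows (an assumed, simplified implementation for illustration; the real seq_to_ann_forward also supports a tuple or list of modules, as used above):

import torch

def seq_to_ann_forward_sketch(x_seq: torch.Tensor, stateless_module):
    # merge the time and batch dimensions: [T, N, *] -> [T * N, *],
    # so all time steps are computed in one large batch
    y = stateless_module(x_seq.flatten(0, 1))
    # split them back apart: [T * N, *] -> [T, N, *]
    return y.reshape(x_seq.shape[0], x_seq.shape[1], *y.shape[1:])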

2. Commonly used layers

spikingjelly.activation_based.layer has already been introduced. It is more recommended to use the layers from spikingjelly.activation_based.layer than to wrap manually with SeqToANNContainer, even though the layers in spikingjelly.activation_based.layer are themselves implemented by wrapping the forward function. The advantages of the layers in spikingjelly.activation_based.layer are:

  • They support both single-step and multi-step modes, whereas layers wrapped by SeqToANNContainer or MultiStepContainer support only multi-step mode

  • Wrappers add an extra level of nesting to the keys() of the state_dict, which makes loading weights troublesome; the layers in spikingjelly.activation_based.layer avoid this

For example:

import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer, neuron


ann = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU()
)

print(f'ann.state_dict.keys()={ann.state_dict().keys()}')

net_container = nn.Sequential(
    layer.SeqToANNContainer(
        nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(8),
    ),
    neuron.IFNode(step_mode='m')
)
print(f'net_container.state_dict.keys()={net_container.state_dict().keys()}')

net_origin = nn.Sequential(
    layer.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    neuron.IFNode(step_mode='m')
)
print(f'net_origin.state_dict.keys()={net_origin.state_dict().keys()}')

try:
    print('net_container is trying to load state dict from ann...')
    net_container.load_state_dict(ann.state_dict())
    print('Load success!')
except BaseException as e:
    print('net_container can not load! The error message is\n', e)

try:
    print('net_origin is trying to load state dict from ann...')
    net_origin.load_state_dict(ann.state_dict())
    print('Load success!')
except BaseException as e:
    print('net_origin can not load! The error message is', e)

We can see that the key names of the first network (ann) and the third (net_origin) match, while the second (net_container) carries an extra level of nesting, which is why only net_origin can successfully load ann's weights.

The output is:

ann.state_dict.keys()=odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked'])
net_container.state_dict.keys()=odict_keys(['0.0.weight', '0.1.weight', '0.1.bias', '0.1.running_mean', '0.1.running_var', '0.1.num_batches_tracked'])
net_origin.state_dict.keys()=odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked'])
net_container is trying to load state dict from ann...
net_container can not load! The error message is
Error(s) in loading state_dict for Sequential:
Missing key(s) in state_dict: "0.0.weight", "0.1.weight", "0.1.bias", "0.1.running_mean", "0.1.running_var".
Unexpected key(s) in state_dict: "0.weight", "1.weight", "1.bias", "1.running_mean", "1.running_var", "1.num_batches_tracked".
net_origin is trying to load state dict from ann...
Load success!
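
If you do need to load plain ANN weights into a container-wrapped network, one workaround is to remap the keys before loading. This is plain PyTorch dictionary manipulation, not a SpikingJelly API, and the '0.' prefix below is specific to this example's module layout:

# Prefix the conv/bn keys so that '0.weight' becomes '0.0.weight',
# matching net_container's naming; the required prefix depends on
# where the container sits in the network.
remapped = {'0.' + k: v for k, v in ann.state_dict().items()}
net_container.load_state_dict(remapped)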

Both MultiStepContainer and SeqToANNContainer support only multi-step mode and do not allow switching to single-step mode.


3. StepModeContainer works like a fused version of MultiStepContainer and SeqToANNContainer. It can wrap stateless or stateful single-step modules; whether the wrapped module is stateful must be specified at construction time. Unlike the containers above, this wrapper supports switching between single-step and multi-step modes.

Example of wrapping stateless layers:

import torch
import torch.nn as nn
from spikingjelly.activation_based import layer


with torch.no_grad():
    T = 4
    N = 2
    C = 4
    H = 8
    W = 8
    x_seq = torch.rand([T, N, C, H, W])
    # the first argument marks whether the wrapped modules are stateful
    net = layer.StepModeContainer(
        False,
        nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(C),
    )
    net.step_mode = 'm'
    y_seq = net(x_seq)
    # y_seq.shape = [T, N, C, H, W]

    net.step_mode = 's'
    y = net(x_seq[0])
    # y.shape = [N, C, H, W]

Example of wrapping a stateful layer:

import torch
from spikingjelly.activation_based import neuron, layer, functional


with torch.no_grad():
    T = 4
    N = 2
    C = 4
    H = 8
    W = 8
    x_seq = torch.rand([T, N, C, H, W])
    net = layer.StepModeContainer(
        True,
        neuron.IFNode()
    )
    net.step_mode = 'm'
    y_seq = net(x_seq)
    # y_seq.shape = [T, N, C, H, W]
    functional.reset_net(net)

    net.step_mode = 's'
    y = net(x_seq[0])
    # y.shape = [N, C, H, W]
    functional.reset_net(net)

Using set_step_mode on a StepModeContainer is safe: it only changes the step_mode of the wrapper itself, while the modules inside the wrapper keep their single-step mode:

import torch
from spikingjelly.activation_based import neuron, layer, functional


with torch.no_grad():
    net = layer.StepModeContainer(
        True,
        neuron.IFNode()
    )
    functional.set_step_mode(net, 'm')
    print(f'net.step_mode={net.step_mode}')
    print(f'net[0].step_mode={net[0].step_mode}')

The output is:

net.step_mode=m
net[0].step_mode=s

If a module itself already supports switching between single-step and multi-step modes, wrapping it with MultiStepContainer or StepModeContainer is not recommended, because the multi-step forward propagation used by the wrapper may be slower than the module's own multi-step implementation.

MultiStepContainer and StepModeContainer are typically needed for modules that do not define multi-step propagation, for example a layer that exists in torch.nn but not in spikingjelly.activation_based.layer.
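
For instance, a stateless torch.nn layer can be wrapped with StepModeContainer and then used in multi-step mode. This sketch uses nn.PixelShuffle under the assumption that no multi-step version of it is provided in spikingjelly.activation_based.layer:

import torch
import torch.nn as nn
from spikingjelly.activation_based import layer, functional

with torch.no_grad():
    T = 4
    N = 2
    C = 8
    H = 8
    W = 8
    x_seq = torch.rand([T, N, C, H, W])
    # wrap the stateless torch.nn module; False marks it as stateless
    net = layer.StepModeContainer(False, nn.PixelShuffle(upscale_factor=2))
    functional.set_step_mode(net, 'm')  # only the wrapper switches to multi-step
    y_seq = net(x_seq)
    # y_seq.shape = [T, N, C // 4, H * 2, W * 2]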