
1. class BertLayer(nn.Module):

Initialization function:

def __init__(self, config):
    super().__init__()
    self.chunk_size_feed_forward = config.chunk_size_feed_forward
    self.seq_len_dim = 1
    self.attention = BertAttention(config)
    self.is_decoder = config.is_decoder
    self.add_cross_attention = config.add_cross_attention
    if self.add_cross_attention:
        if not self.is_decoder:
            raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
        self.crossattention = BertAttention(config, position_embedding_type="absolute")
    self.intermediate = BertIntermediate(config)
    self.output = BertOutput(config)

Forward function:

def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
    # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
    self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
    self_attention_outputs = self.attention(
        hidden_states,
        attention_mask,
        head_mask,
        output_attentions=output_attentions,
        past_key_value=self_attn_past_key_value,
    )  # in this model the attention returns a single-element tuple, so the slices below are effectively empty
    attention_output = self_attention_outputs[0]

    # if decoder, the last output is tuple of self-attn cache
    if self.is_decoder:
        outputs = self_attention_outputs[1:-1]
        present_key_value = self_attention_outputs[-1]
    else:
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

    cross_attn_present_key_value = None
    if self.is_decoder and encoder_hidden_states is not None:
        if not hasattr(self, "crossattention"):
            raise ValueError(
                f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                " by setting `config.add_cross_attention=True`"
            )

        # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
        cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
        cross_attention_outputs = self.crossattention(
            attention_output,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            cross_attn_past_key_value,
            output_attentions,
        )
        attention_output = cross_attention_outputs[0]
        outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

        # add cross-attn cache to positions 3,4 of present_key_value tuple
        cross_attn_present_key_value = cross_attention_outputs[-1]
        present_key_value = present_key_value + cross_attn_present_key_value

    layer_output = apply_chunking_to_forward(
        self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
    )  # applies the feed-forward block, optionally chunked along the sequence dimension
    outputs = (layer_output,) + outputs

    # if decoder, return the attn key/values as the last output
    if self.is_decoder:
        outputs = outputs + (present_key_value,)

    # print(layer_output.shape)

    return outputs

Feed-forward block:

def feed_forward_chunk(self, attention_output):  # is chunking really necessary here?
    intermediate_output = self.intermediate(attention_output)
    layer_output = self.output(intermediate_output, attention_output)
    return layer_output
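
The question in the comment above ("is chunking really necessary?") comes down to what apply_chunking_to_forward does: with chunk_size_feed_forward > 0 it splits the input along seq_len_dim, runs feed_forward_chunk on each slice, and concatenates the results, trading scheduling for a lower peak memory; with the default chunk_size_feed_forward = 0 it simply calls the function once. Below is a minimal sketch of that behaviour (a simplified stand-in, not the Hugging Face implementation itself):

# Minimal sketch of what apply_chunking_to_forward does here (simplified stand-in;
# chunk_size == 0 means "no chunking", which is the default config).
import torch

def chunked_feed_forward(feed_forward_chunk, chunk_size, seq_len_dim, attention_output):
    if chunk_size == 0:
        # default case: apply the feed-forward block to the whole tensor at once
        return feed_forward_chunk(attention_output)
    # otherwise split along the sequence dimension, process each chunk, and re-concatenate
    chunks = attention_output.split(chunk_size, dim=seq_len_dim)
    return torch.cat([feed_forward_chunk(c) for c in chunks], dim=seq_len_dim)

# usage sketch with a stand-in feed-forward function
ff = lambda x: x * 2.0
x = torch.randn(32, 128, 8)                                   # [B, L, D]; seq_len_dim = 1
y = chunked_feed_forward(ff, chunk_size=8, seq_len_dim=1, attention_output=x)
assert torch.equal(y, ff(x))                                  # chunking changes peak memory, not the result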

self_attn_past_key_value: None
self_attention_outputs: a tuple with a single element
attention_output: tensor[4,32,128,1536]
layer_output:

  • —Intermediate—
  • tensor[4,32,128,1536] -> (Intermediate = SpikeLinear).PSN -> 0/1 spike tensor of shape [4,32,128,1536]
  • 0/1 spike tensor [4,32,128,1536] -> (Intermediate = SpikeLinear) nn.functional.linear + self.bias -> (Intermediate = SpikeLinear).scale -> tensor[4,32,128,3072]
  • —output—
  • tensor[4,32,128,3072] -> output.SpikeLinear.PSN -> 0/1 spike tensor of shape [4,32,128,3072]
  • 0/1 spike tensor [4,32,128,3072] -> output.SpikeLinear nn.functional.linear + self.bias -> output.SpikeLinear.scale -> tensor[4,32,128,1536]
  • tensor[4,32,128,1536] -> output.dropout -> tensor[4,32,128,1536]
  • tensor[4,32,128,1536] -> output.LayerNorm -> tensor[4,32,128,1536]

outputs: a tuple containing only layer_output
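
The SpikeLinear module referenced in this trace is not shown in this note. Purely as a reading aid, here is a hypothetical reconstruction that follows the dataflow traced above (PSN -> 0/1 spikes -> nn.functional.linear + bias -> scale); the class name, the scale parameter, and every other detail are assumptions, not the project's actual code:

# Hypothetical reconstruction of SpikeLinear, inferred only from the trace above.
# The real implementation is not shown here; treat every name as an assumption.
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import neuron


class SpikeLinearSketch(nn.Module):
    def __init__(self, in_features: int, out_features: int, T: int = 4):
        super().__init__()
        self.psn = neuron.PSN(T=T)                          # turns real-valued activations into 0/1 spikes
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.scale = nn.Parameter(torch.ones(1))            # rescales the spike-driven linear output
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)

    def forward(self, x):                                   # x: [T, B, L, in_features]
        spikes = self.psn(x)                                # 0/1 tensor, same shape as x
        out = F.linear(spikes, self.weight) + self.bias     # [T, B, L, out_features]
        return out * self.scale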




2. class BertEncoder(nn.Module):

Initialization function:

def __init__(self, config):
    super().__init__()
    self.config = config
    self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
    self.gradient_checkpointing = False

Forward function:

def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = False,
    output_hidden_states: Optional[bool] = False,
    return_dict: Optional[bool] = True,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
    all_hidden_states = () if output_hidden_states else None
    all_self_attentions = () if output_attentions else None
    # all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

    # if self.gradient_checkpointing and self.training:
    #     if use_cache:
    #         logger.warning_once(
    #             "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
    #         )
    #         use_cache = False

    next_decoder_cache = () if use_cache else None

    ######################
    # add a leading time dimension: [B, L, D] -> [T, B, L, D] with T = 4
    hidden_states = hidden_states.repeat(tuple([4] + torch.ones(len(hidden_states.size()), dtype=int).tolist()))  # T B L D
    # hidden_states = hidden_states.transpose(0, 1)  # B T L D
    ######################
    for i, layer_module in enumerate(self.layer):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        layer_head_mask = head_mask[i] if head_mask is not None else None
        past_key_value = past_key_values[i] if past_key_values is not None else None

        # if self.gradient_checkpointing and self.training:
        #
        #     def create_custom_forward(module):
        #         def custom_forward(*inputs):
        #             return module(*inputs, past_key_value, output_attentions)
        #
        #         return custom_forward
        #
        #     layer_outputs = torch.utils.checkpoint.checkpoint(
        #         create_custom_forward(layer_module),
        #         hidden_states,
        #         attention_mask,
        #         layer_head_mask,
        #         encoder_hidden_states,
        #         encoder_attention_mask,
        #     )
        # else:
        layer_outputs = layer_module(
            hidden_states,
            attention_mask,
            layer_head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )

        hidden_states = layer_outputs[0]
        if use_cache:
            next_decoder_cache += (layer_outputs[-1],)
        if output_attentions:
            all_self_attentions = all_self_attentions + (layer_outputs[1],)  # not executed in this run
            # if self.config.add_cross_attention:
            #     all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states.mean(0),)
    #########################
    # collapse the time dimension by averaging: [T, B, L, D] -> [B, L, D]
    hidden_states = hidden_states.mean(0)
    #########################

    if not return_dict:
        return tuple(
            v
            for v in [
                hidden_states,
                next_decoder_cache,
                all_hidden_states,
                all_self_attentions,
                None,
            ]
            if v is not None
        )
    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=next_decoder_cache,
        hidden_states=all_hidden_states,  ########################## all_hidden_states
        attentions=all_self_attentions,
        cross_attentions=None,
    )

all_hidden_states: ()
all_self_attentions: None
next_decoder_cache: None
hidden_states: tensor[32,128,1536] -> repeat -> tensor[4,32,128,1536]

—BertLayer—
self_attn_past_key_value: None
self_attention_outputs: a tuple with a single element
attention_output: tensor[4,32,128,1536]
layer_output:

  • —Intermediate—
  • tensor[4,32,128,1536] -> (Intermediate = SpikeLinear).PSN -> 0/1 spike tensor of shape [4,32,128,1536]
  • 0/1 spike tensor [4,32,128,1536] -> (Intermediate = SpikeLinear) nn.functional.linear + self.bias -> (Intermediate = SpikeLinear).scale -> tensor[4,32,128,3072]
  • —output—
  • tensor[4,32,128,3072] -> output.SpikeLinear.PSN -> 0/1 spike tensor of shape [4,32,128,3072]
  • 0/1 spike tensor [4,32,128,3072] -> output.SpikeLinear nn.functional.linear + self.bias -> output.SpikeLinear.scale -> tensor[4,32,128,1536]
  • tensor[4,32,128,1536] -> output.dropout -> tensor[4,32,128,1536]
  • tensor[4,32,128,1536] -> output.LayerNorm -> tensor[4,32,128,1536]

outputs: a tuple containing only layer_output
—BertLayer—

layer_outputs = BertLayer.outputs: a tuple whose only element is tensor[4,32,128,1536]
hidden_states = the first element of layer_outputs, i.e. tensor[4,32,128,1536]
hidden_states: tensor[4,32,128,1536] -> mean(0) -> tensor[32,128,1536]
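
The repeat at the top of BertEncoder.forward and the mean(0) at the bottom are what give the spiking layers their time dimension: the [B, L, D] input is copied T = 4 times into [T, B, L, D], and the time dimension is averaged away again after the layer stack. A small standalone check of those two steps (shapes shrunk so it runs anywhere; in the trace above B = 32, L = 128, D = 1536):

# Shape check for the time-dimension wrapper in BertEncoder.forward (T = 4 is hard-coded).
import torch

hidden_states = torch.randn(2, 3, 5)          # [B, L, D]; in the real run this is [32, 128, 1536]
T = 4

# same expression as in the encoder: prepend T and tile, keeping the other dims unchanged
repeated = hidden_states.repeat(tuple([T] + torch.ones(len(hidden_states.size()), dtype=int).tolist()))
print(repeated.shape)                         # torch.Size([4, 2, 3, 5])
assert torch.equal(repeated[0], hidden_states)  # every time-step starts from the same copy

# after the layer stack, the time dimension is collapsed by averaging
collapsed = repeated.mean(0)
print(collapsed.shape)                        # torch.Size([2, 3, 5])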


3. PSN

Charging state of the neuron:

H = WX, \quad W \in \mathbb{R}^{T \times T}, X \in \mathbb{R}^{T \times N}

Here, H is the neuron's state, W is a learnable weight matrix, and X is the input.

h_seq = torch.addmm(self.bias, self.weight, x_seq.flatten(1))
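
torch.addmm(bias, weight, x) computes bias + weight @ x with the bias broadcast across columns, so this one call implements the charging equation; note that self.bias (initialized to -1) plays the role of -B in S = Θ(H - B). A quick check of the equivalence:

# torch.addmm(bias, W, X) == bias + W @ X (bias is broadcast), i.e. the charging step in one call.
import torch

T, N = 4, 6
W = torch.randn(T, T)
B = torch.randn(T, 1)
X = torch.randn(T, N)

h1 = torch.addmm(B, W, X)
h2 = B + W @ X
assert torch.allclose(h1, h2)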

Spike generation:

S = \Theta(H - B), \quad B \in \mathbb{R}^{T}, S \in \{0, 1\}^{T \times N}

Here, S is the generated spike train, Θ is the threshold function (e.g. the Heaviside step), and B is a learnable threshold.

spike_seq = self.surrogate_function(h_seq)
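
The surrogate function (ATan by default) behaves like a Heaviside step in the forward pass, so the output is a 0/1 spike tensor, while the backward pass uses a smooth surrogate gradient. Below is a simplified stand-in, not the library code; the α = 2 default and the atan-shaped derivative follow the common ATan surrogate, so treat the exact form as an assumption:

# Minimal sketch of a surrogate-gradient Heaviside (simplified stand-in for surrogate.ATan).
import math
import torch


class AtanSurrogateSpike(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha: float = 2.0):
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        return (x >= 0).to(x)                    # Heaviside step: 0/1 spikes

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        alpha = ctx.alpha
        # smooth derivative used instead of the true (zero almost everywhere) gradient
        sg = alpha / 2 / (1 + (math.pi / 2 * alpha * x) ** 2)
        return grad_output * sg, None


h = torch.randn(4, 5, requires_grad=True)
s = AtanSurrogateSpike.apply(h)                  # binary in the forward pass
s.sum().backward()                               # h.grad is non-zero thanks to the surrogate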

Full code:

class PSN(nn.Module, base.MultiStepModule):
    def __init__(self, T: int, surrogate_function: surrogate.SurrogateFunctionBase = surrogate.ATan()):
        """
        :param T: the number of time-steps
        :type T: int
        :param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
        :type surrogate_function: Callable

        .. admonition:: Note
            :class: note

            The PSN only supports the multi-step mode.
        """
        super().__init__()
        self.T = T
        self.surrogate_function = surrogate_function
        weight = torch.zeros([T, T])
        bias = torch.zeros([T, 1])

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias)

        # initialize the weight and bias
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.constant_(self.bias, -1.)

    def forward(self, x_seq: torch.Tensor):
        # x_seq.shape = [T, N, *]
        # print(x_seq.shape)             # [4, 32, 128, 1536] / [4, 32, 128, 3072]
        # print(x_seq.flatten(1).shape)  # [4, 32x128x1536] / [4, 32x128x3072]
        h_seq = torch.addmm(self.bias, self.weight, x_seq.flatten(1))  # matrix multiplication plus bias
        # h_seq.shape = [4, 32x128x1536] / [4, 32x128x3072]

        spike_seq = self.surrogate_function(h_seq)  # same shape, but converted into 0/1 spikes

        return spike_seq.view(x_seq.shape)  # reshape spike_seq back to the shape of x_seq

    def extra_repr(self):
        return super().extra_repr() + f', T={self.T}'
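
A quick usage sketch of PSN on a [T, N, *] input, matching how it is called inside SpikeLinear here (T = 4; the other sizes are shrunk from the traced [4, 32, 128, 1536] so the snippet runs anywhere):

# Usage sketch: PSN over a [T, N, *] input.
import torch
from spikingjelly.activation_based import neuron

lif = neuron.PSN(T=4)
x_seq = torch.randn(4, 2, 8, 16)          # [T, B, L, D]
spikes = lif(x_seq)

print(spikes.shape)                       # torch.Size([4, 2, 8, 16])
print(spikes.unique())                    # values drawn from {0., 1.}: a pure spike tensor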

4. MaskedPSN

Charging state of the neuron:

H = (W \cdot M_k)X, \quad W \in \mathbb{R}^{T \times T}, M_k \in \mathbb{R}^{T \times T}, X \in \mathbb{R}^{T \times N}

Here H is the neuron's state, W is a learnable weight matrix, M_k is the mask matrix, and X is the input.

h_seq = torch.addmm(self.bias, self.masked_weight(), x_seq.flatten(1))

Spike generation:

S = \Theta(H - B), \quad B \in \mathbb{R}^{T}, S \in \{0, 1\}^{T \times N}

Here, S is the generated spike train, Θ is the threshold function (e.g. the Heaviside step), and B is a learnable threshold.

spike_seq = self.surrogate_function(h_seq).view(x_seq.shape)

Mask matrix M_k

  • M_k is defined by the positions of its elements:

M_k[i][j] = \begin{cases} 1, & \text{if } j \leq i \leq j + k - 1 \\ 0, & \text{otherwise} \end{cases}

In other words, each entry of M_k is determined by its position in the matrix; this structure restricts (masks) the weight matrix W to a local temporal band.

mask0 = torch.tril(mask1) * torch.triu(mask1, -(self.k - 1))
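
Concretely, this one line builds a banded lower-triangular matrix: row i keeps only the columns j with i - k + 1 ≤ j ≤ i. For T = 4 and k = 2:

# mask0 for T = 4, k = 2: each row i attends only to time-steps i-k+1 ... i.
import torch

T, k = 4, 2
mask1 = torch.ones([T, T])
mask0 = torch.tril(mask1) * torch.triu(mask1, -(k - 1))
print(mask0)
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [0., 1., 1., 0.],
#         [0., 0., 1., 1.]])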

λ and the progressive masking process

  • λ adjusts the masking process and defines a progressive mask matrix M_k(λ):

M_k(\lambda) = \lambda \cdot M_k + (1 - \lambda) \cdot J

Here J is the all-ones matrix. By adjusting λ, the model can transition smoothly between the full mask M_k and full connectivity (the all-ones matrix J).

def gen_masked_weight(lambda_: torch.Tensor, mask0: torch.Tensor, mask1: torch.Tensor, weight: torch.Tensor):
    return (lambda_ * mask0 + (1. - lambda_) * mask1) * weight
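
Plugging in the two extremes of λ shows the progressive transition: at λ = 0 the weight is multiplied by the all-ones mask1 (no masking), and at λ = 1 only the k-banded entries selected by mask0 survive. A small numeric check:

# Progressive masking: lambda_ = 0 keeps W fully connected, lambda_ = 1 applies the k-band mask.
import torch

def gen_masked_weight(lambda_, mask0, mask1, weight):
    return (lambda_ * mask0 + (1. - lambda_) * mask1) * weight

T, k = 4, 2
mask1 = torch.ones([T, T])
mask0 = torch.tril(mask1) * torch.triu(mask1, -(k - 1))
W = torch.randn(T, T)

assert torch.equal(gen_masked_weight(torch.tensor(0.), mask0, mask1, W), W)          # no masking yet
assert torch.equal(gen_masked_weight(torch.tensor(1.), mask0, mask1, W), W * mask0)  # fully masked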

class MaskedPSN(base.MemoryModule):
    @staticmethod
    @torch.jit.script
    def gen_masked_weight(lambda_: torch.Tensor, mask0: torch.Tensor, mask1: torch.Tensor, weight: torch.Tensor):
        return (lambda_ * mask0 + (1. - lambda_) * mask1) * weight

    def masked_weight(self):
        if self.lambda_ >= 1.:
            return self.weight * self.mask0
        else:
            return self.gen_masked_weight(self.lambda_, self.mask0, self.mask1, self.weight)

    def __init__(self, k: int, T: int, lambda_init: float = 0.,
                 surrogate_function: surrogate.SurrogateFunctionBase = surrogate.ATan(), step_mode: str = 's'):
        """
        :param k: the order of the Masked PSN
        :type k: int

        :param T: the number of time-steps
        :type T: int

        :param lambda_init: the initial value of :math:`\\lambda` to adjust the progressive masking process
        :type lambda_init: float

        :param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
        :type surrogate_function: Callable

        :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
        :type step_mode: str

        .. admonition:: Note
            :class: note

            The masked PSN supports both single-step and multi-step mode. But using the multi-step mode is much faster than the single-step mode.
        """
        super().__init__()
        self.register_memory('time_step', 0)
        self.register_memory('queue', [])
        self.step_mode = step_mode
        self.k = k
        self.T = T
        self.surrogate_function = surrogate_function
        weight = torch.zeros([T, T])
        bias = torch.zeros([T, 1])
        self.register_buffer('_lambda_', torch.as_tensor(lambda_init))

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias)

        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.constant_(self.bias, -1.)

        mask1 = torch.ones([T, T])
        mask0 = torch.tril(mask1) * torch.triu(mask1, -(self.k - 1))
        self.register_buffer('mask0', mask0)
        self.register_buffer('mask1', mask1)

    def single_step_forward(self, x: torch.Tensor):
        if self.lambda_ < 1.:
            raise ValueError("The masked PSN can not work in single-step mode when k < 1!")

        self.queue.append(x.flatten())
        if self.queue.__len__() > self.k:
            self.queue.pop(0)

        if self.time_step + 1 > self.T:
            raise OverflowError(f"The MaskedPSN(T={self.T}) has run {self.time_step + 1} time-steps!")

        weight = self.masked_weight()[self.time_step, self.time_step + 1 - self.queue.__len__(): self.time_step + 1]
        x_seq = torch.stack(self.queue)

        for i in range(x.dim()):
            weight = weight.unsqueeze(-1)

        h = torch.sum(weight * x_seq, 0)
        spike = self.surrogate_function(h + self.bias[self.time_step])

        self.time_step += 1
        return spike.view(x.shape)

    def multi_step_forward(self, x_seq: torch.Tensor):
        # x_seq.shape = [T, N, *]
        assert x_seq.shape[0] == self.T
        h_seq = torch.addmm(self.bias, self.masked_weight(), x_seq.flatten(1))
        spike_seq = self.surrogate_function(h_seq).view(x_seq.shape)
        return spike_seq

    @property
    def lambda_(self):
        return self._lambda_

    @lambda_.setter
    def lambda_(self, value: float):
        torch.fill_(self.lambda_, value)

    def extra_repr(self):
        return super().extra_repr() + f', lambda_={self.lambda_}, T={self.T}'
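
A usage sketch of the progressive schedule: lambda_ is typically pushed from lambda_init toward 1 during training, so the fully connected PSN gradually becomes the k-order masked PSN. The linear schedule below is only an illustration, not something prescribed by the library:

# Usage sketch: multi-step MaskedPSN with lambda_ annealed toward 1 over training.
import torch
from spikingjelly.activation_based import neuron

T, k = 4, 2
lif = neuron.MaskedPSN(k=k, T=T, lambda_init=0., step_mode='m')
x_seq = torch.randn(T, 2, 8)                 # [T, N, D]

num_epochs = 10
for epoch in range(num_epochs):
    # simple linear schedule reaching 1 halfway through (an illustration only)
    lif.lambda_ = min(1., (epoch + 1) / (num_epochs // 2))
    spikes = lif(x_seq)                      # [T, 2, 8] 0/1 tensor

print(lif.lambda_)                           # tensor(1.)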

5. SlidingPSN

H[t] = \sum_{i=0}^{k-1} W_{i} \cdot X[t - k + 1 + i]

S[t] = \Theta(H[t] - B)

Here, W is a learnable weight, B is a learnable threshold, and Θ is the threshold function.

class SlidingPSN(base.MemoryModule):

    @property
    def supported_backends(self):
        return 'gemm', 'conv'

    def gen_gemm_weight(self, T: int):
        weight = torch.zeros([T, T], device=self.weight.device)
        for i in range(T):
            end = i + 1
            start = max(0, i + 1 - self.k)
            length = min(end - start, self.k)
            weight[i][start: end] = self.weight[self.k - length: self.k]

        return weight

    def __init__(self, k: int, exp_init: bool = True,
                 surrogate_function: surrogate.SurrogateFunctionBase = surrogate.ATan(), step_mode: str = 's',
                 backend: str = 'gemm'):
        """
        :param k: the order of the Sliding PSN
        :type k: int

        :param exp_init: if ``True``, the weight will be initialized as ``(..., 1/4, 1/2, 1)``. If ``False``, the weight will be initialized by the kaiming uniform
        :type exp_init: bool

        :param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
        :type surrogate_function: Callable

        :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
        :type step_mode: str

        :param backend: backend for this neuron layer, which can be "gemm" or "conv". This option only works for the multi-step mode
        :type backend: str

        .. admonition:: Note
            :class: note

            The Sliding PSN supports both single-step and multi-step mode. But using the multi-step mode is much faster than the single-step mode.
        """

        super().__init__()
        self.register_memory('queue', [])
        self.step_mode = step_mode
        self.k = k
        self.surrogate_function = surrogate_function
        self.backend = backend

        if exp_init:
            weight = torch.ones([k])
            for i in range(k - 2, -1, -1):
                weight[i] = weight[i + 1] / 2.
        else:
            weight = torch.ones([1, k])
            nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
            weight = weight[0]

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(torch.as_tensor(-1.))

    def single_step_forward(self, x: torch.Tensor):
        self.queue.append(x.flatten())
        if self.queue.__len__() > self.k:
            self.queue.pop(0)

        weight = self.weight[self.k - self.queue.__len__(): self.k]
        x_seq = torch.stack(self.queue)

        for i in range(x.dim()):
            weight = weight.unsqueeze(-1)

        h = torch.sum(weight * x_seq, 0)
        spike = self.surrogate_function(h + self.bias)

        return spike.view(x.shape)

    def multi_step_forward(self, x_seq: torch.Tensor):
        if self.backend == 'gemm':
            weight = self.gen_gemm_weight(x_seq.shape[0])
            h_seq = torch.addmm(self.bias, weight, x_seq.flatten(1)).view(x_seq.shape)
            return self.surrogate_function(h_seq)
        elif self.backend == 'conv':
            # x_seq.shape = [T, N, *]
            x_seq_shape = x_seq.shape
            # [T, N, *] -> [T, N] -> [N, T] -> [N, 1, T]
            x_seq = x_seq.flatten(1).t().unsqueeze(1)

            x_seq = F.pad(x_seq, pad=(self.k - 1, 0))
            x_seq = F.conv1d(x_seq, self.weight.view(1, 1, -1), stride=1)

            x_seq = x_seq.squeeze(1).t().view(x_seq_shape)
            return self.surrogate_function(x_seq + self.bias)

        else:
            raise NotImplementedError(self.backend)

    def extra_repr(self):
        return super().extra_repr() + f', order={self.k}'
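
The two multi-step backends compute the same spikes: 'gemm' materializes the T×T banded matrix from gen_gemm_weight and uses a single addmm, while 'conv' left-pads by k - 1 and slides the length-k kernel with F.conv1d. A quick consistency sketch (small shapes; exact equality assumes no value lands exactly on the spike threshold):

# The 'gemm' and 'conv' backends of SlidingPSN produce the same spikes.
import torch
from spikingjelly.activation_based import neuron

x_seq = torch.randn(6, 2, 8)                      # [T, N, D]

psn_gemm = neuron.SlidingPSN(k=3, step_mode='m', backend='gemm')
psn_conv = neuron.SlidingPSN(k=3, step_mode='m', backend='conv')
psn_conv.load_state_dict(psn_gemm.state_dict())   # make sure both share the same weight/bias

s_gemm = psn_gemm(x_seq)
s_conv = psn_conv(x_seq)
assert torch.equal(s_gemm, s_conv)                # identical 0/1 outputs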

6. Memory usage

self.lif =  neuron.PSN(T=128)

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.50 GiB (GPU 6; 23.70 GiB total capacity; 22.28 GiB already allocated; 231.69 MiB free; 22.41 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF


self.lif =  neuron.SlidingPSN(k=4,exp_init=False,step_mode='m',backend='conv')

Although the GPU has 23.70 GiB of total capacity, 20.84 GiB was already allocated and only 1.41 GiB remained free, which could not satisfy the additional request.


self.lif =  neuron.SlidingPSN(k=5,exp_init=False,step_mode='m',backend='conv')

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.50 GiB (GPU 4; 23.70 GiB total capacity; 20.87 GiB already allocated; 1.39 GiB free; 21.00 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF


self.lif =  neuron.SlidingPSN(k=10,exp_init=False,step_mode='m',backend='conv')

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.50 GiB (GPU 0; 23.70 GiB total capacity; 20.98 GiB already allocated; 1.27 GiB free; 21.12 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Tentative conclusion: the larger k is, the more memory is consumed.


self.lif =  neuron.SlidingPSN(k=3,exp_init=False,step_mode='m',backend='conv')

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.50 GiB (GPU 3; 23.70 GiB total capacity; 20.82 GiB already allocated; 1.43 GiB free; 20.96 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

The free memory is getting closer and closer to the amount that needs to be allocated.


self.lif =  neuron.SlidingPSN(k=3,exp_init=False,step_mode='m',backend='gemm')

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.50 GiB (GPU 7; 23.70 GiB total capacity; 22.27 GiB already allocated; 235.69 MiB free; 22.41 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

This shows that setting the backend to gemm consumes a lot of memory.


self.lif =  neuron.SlidingPSN(k=3,exp_init=True,step_mode='m',backend='conv')

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.50 GiB (GPU 3; 23.70 GiB total capacity; 20.82 GiB already allocated; 1.43 GiB free; 20.96 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

This also shows that exp_init has no effect on memory usage.


self.lif =  neuron.SlidingPSN(k=3,exp_init=False,step_mode='s',backend='conv')

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.50 GiB (GPU 4; 23.70 GiB total capacity; 22.27 GiB already allocated; 235.69 MiB free; 22.41 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

This shows that single-step mode consumes a lot of memory.


All of the results above use batch size 16; below is the result with batch size 32.

self.lif =  neuron.SlidingPSN(k=3,exp_init=True,step_mode='m',backend='conv')

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.00 GiB (GPU 3; 23.70 GiB total capacity; 19.34 GiB already allocated; 2.94 GiB free; 19.45 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
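
The OOM messages above only report the failing allocation; for a direct comparison between configurations it is easier to measure the peak allocation around one forward/backward pass. A minimal measurement sketch using PyTorch's allocator counters (the module and the [T, B, L, D] batch below are placeholders standing in for the actual model and data):

# Sketch: measuring peak CUDA memory for one forward/backward pass of a candidate module.
import torch
from spikingjelly.activation_based import neuron

device = 'cuda'
lif = neuron.SlidingPSN(k=3, exp_init=False, step_mode='m', backend='conv').to(device)
x = torch.randn(4, 32, 128, 1536, device=device, requires_grad=True)   # placeholder [T, B, L, D] batch

torch.cuda.reset_peak_memory_stats(device)
out = lif(x)
out.sum().backward()
torch.cuda.synchronize(device)
peak_gib = torch.cuda.max_memory_allocated(device) / 1024 ** 3
print(f'peak allocated: {peak_gib:.2f} GiB')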