梯度替代
1.提出问题
在神经元章节中我们已经提到过,描述神经元放电过程的 S[t]=Θ(H[t]−Vthreshold),使用了一个Heaviside阶跃函数:
Θ(x)={1,0,x≥0x<0
按照定义,其导数为冲激函数:
δ(x)={+∞,0,x=0x=0
直接使用冲激函数进行梯度下降,显然会使得网络的训练及其不稳定。为了解决这一问题,各种梯度替代法(the surrogate gradient method)被相继提出。
2.代理梯度函数的应用
替代函数在神经元中被用于生成脉冲,查看 BaseNode.neuronal_fire 的源代码可以发现:
1 2 3 4 5 6 7 8 9
| class BaseNode(base.MemoryModule): def __init__(..., surrogate_function: Callable = surrogate.Sigmoid(), ...) self.surrogate_function = surrogate_function
def neuronal_fire(self): return self.surrogate_function(self.v - self.v_threshold)
|
3.梯度替代法原理
梯度替代法的原理:在前向传播时使用 y=Θ(x),而在反向传播时则使用 dxdy=σ′(x),而非
dxdy=Θ′(x)
- 其中 σ(x) 即为替代函数。σ(x) 通常是一个形状与 Θ(x)类似,但光滑连续的函数。
4.使用替代函数
在 spikingjelly.activation_based.surrogate 中提供了一些常用的替代函数,其中Sigmoid函数 σ(x,α)=1+exp(−αx)1为 Sigmoid,下图展示了原始的Heaviside阶跃函数 Heaviside
、 alpha=5
时的Sigmoid原函数 Primitive
以及其梯度 Gradient
:
替代函数的使用比较简单,使用替代函数就像是使用函数一样:
1 2 3 4 5 6 7 8 9 10 11 12
| import torch from spikingjelly.activation_based import surrogate
sg = surrogate.Sigmoid(alpha=4.) x = torch.rand([8]) - 0.5 x.requires_grad = True y = sg(x) y.sum().backward()
print(f'x={x}') print(f'y={y}') print(f'x.grad={x.grad}')
|
输出为:
1 2 3 4
| x=tensor([-0.1303, 0.4976, 0.3364, 0.4296, 0.2779, 0.4580, 0.4447, 0.2466], requires_grad=True) y=tensor([0., 1., 1., 1., 1., 1., 1., 1.], grad_fn=<sigmoidBackward>) x.grad=tensor([0.9351, 0.4231, 0.6557, 0.5158, 0.7451, 0.4759, 0.4943, 0.7913])
|
5.API风格说明
每个替代函数,除了有形如 Sigmoid 的模块风格API,也提供了形如 sigmoid 函数风格的API。
模块风格的API使用驼峰命名法,而函数风格的API使用下划线命名法,关系类似于 torch.nn
和 torch.nn.functional
,下面是几个示例:
模块 |
函数 |
Sigmoid |
sigmoid |
SoftSign |
soft_sign |
LeakyKReLU |
leaky_k_relu |
下面是函数风格API的用法示例:
1 2 3 4 5 6 7 8 9 10 11 12
| import torch from spikingjelly.activation_based import surrogate
alpha = 4. x = torch.rand([8]) - 0.5 x.requires_grad = True y = surrogate.sigmoid.apply(x, alpha) y.sum().backward()
print(f'x={x}') print(f'y={y}') print(f'x.grad={x.grad}')
|
6.替代函数的超参数
替代函数通常会有1个或多个控制形状的超参数,例如spikingjelly.activation_based.surrogate.Sigmoid 中的 alpha
。SpikingJelly中替代函数的形状参数,默认情况下是使得替代函数梯度最大值为1,这在一定程度上可以避免梯度累乘导致的梯度爆炸问题。