1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
class ResNet_104(nn.Module):
    """ResNet-style classifier that runs an image batch over ``time_window`` steps.

    The input image batch is replicated along a new leading time axis, passed
    through a stem (``conv1``) and four residual stages (``conv2_x`` ..
    ``conv5_x``), then ``mem_update``. The result is reduced by spatial average
    pooling and a mean over the time axis, and classified by a dropout +
    linear head.

    ``Snn_Conv2d``, ``batch_norm_2d``, ``mem_update`` and the module-level
    ``time_window`` are defined outside this class.
    NOTE(review): the ``Snn_`` / ``mem_update`` names suggest spiking-neuron
    layers; confirm against their definitions.

    Args:
        block: Residual block class. It is called as
            ``block(in_channels, out_channels, stride)`` and must expose a
            class attribute ``expansion``. A block's output width is taken to
            be ``out_channels * block.expansion``.
        num_block: Sequence whose first four entries give the number of blocks
            in ``conv2_x``, ``conv3_x``, ``conv4_x`` and ``conv5_x``.
        num_classes: Output size of the final linear layer.
        width_mult: Integer channel-width multiplier applied to the stem and
            every stage (default 1).
        dropout: Dropout probability applied before the linear head
            (default 0.2).
    """

    def __init__(self, block, num_block, num_classes=1000, *, width_mult=1, dropout=0.2):
        super().__init__()
        k = width_mult
        self.in_channels = 64 * k
        # Stem: one stride-2 conv followed by two stride-1 convs, then batch norm.
        self.conv1 = nn.Sequential(
            Snn_Conv2d(3, 64 * k, kernel_size=3, padding=1, stride=2),
            Snn_Conv2d(64 * k, 64 * k, kernel_size=3, padding=1, stride=1),
            Snn_Conv2d(64 * k, 64 * k, kernel_size=3, padding=1, stride=1),
            batch_norm_2d(64 * k),
        )
        # NOTE: not used by forward(). It is kept so existing references to
        # `model.pool` keep working; MaxPool2d holds no parameters or buffers,
        # so state_dicts are unaffected either way.
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.mem_update = mem_update()
        # Every stage, including the first, begins with a stride-2 block.
        self.conv2_x = self._make_layer(block, 64 * k, num_block[0], 2)
        self.conv3_x = self._make_layer(block, 128 * k, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256 * k, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512 * k, num_block[3], 2)
        self.fc = nn.Linear(512 * block.expansion * k, num_classes)
        self.dropout = nn.Dropout(p=dropout)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one residual stage as an ``nn.Sequential`` of ``block`` instances.

        Only the first block uses ``stride``; the remaining blocks use stride 1.
        After each block, ``self.in_channels`` is set to
        ``out_channels * block.expansion``, so the next block (and the next
        stage) is wired to this block's output width.

        Note: ``num_blocks <= 0`` still yields a single block.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on an image batch.

        Args:
            x: 4-D tensor indexed as (batch, channels, height, width). Its
                channel dim must broadcast to 3 (i.e. 1 or 3 channels).

        Returns:
            Output of ``self.fc``, shape (batch, num_classes). This assumes
            the stages preserve the leading (time, batch) axes — TODO confirm
            against ``Snn_Conv2d`` / ``mem_update``.
        """
        # Replicate the image across `time_window` steps into a tensor of
        # shape (time_window, batch, 3, H, W). The broadcast assignment copies
        # x into every time step (broadcasting a 1-channel x to 3 channels).
        # Like torch.zeros, it produces the default float dtype.
        # The buffer is allocated on x's own device rather than a module-level
        # `device`, so the model works wherever it and its input actually live
        # (e.g. nn.DataParallel replicas).
        inputs = torch.zeros(
            time_window, x.size(0), 3, x.size(2), x.size(3), device=x.device
        )
        inputs[:] = x

        out = self.conv1(inputs)
        for stage in (self.conv2_x, self.conv3_x, self.conv4_x, self.conv5_x):
            out = stage(out)
        out = self.mem_update(out)
        # adaptive_avg_pool3d treats the last three dims as (D, H, W): keep
        # dim -3 and average the last two dims down to 1x1.
        out = F.adaptive_avg_pool3d(out, (None, 1, 1))
        # (T, B, C, 1, 1) -> (T, B, C), then average over the time axis (dim 0).
        out = out.flatten(start_dim=2).mean(dim=0)
        out = self.dropout(out)
        return self.fc(out)
|