# Try to load ann's parameters into net_container. Per the printed key lists,
# net_container's keys are nested one level deeper ('0.0.weight', ...) than
# ann's ('0.weight', ...), so this strict load is expected to fail; the error
# is reported instead of being raised.
print('net_container is trying to load state dict from ann...')
try:
    net_container.load_state_dict(ann.state_dict())
except RuntimeError as e:
    # load_state_dict reports missing/unexpected keys via RuntimeError.
    # Catching only that (not BaseException) keeps Ctrl-C / SystemExit working.
    print('net_container can not load! The error message is\n', e)
else:
    print('Load success!')
# Try to load ann's parameters into net_origin. Per the printed key lists,
# net_origin's keys ('0.weight', '1.weight', ...) match ann's, so this strict
# load is expected to succeed.
print('net_origin is trying to load state dict from ann...')
try:
    net_origin.load_state_dict(ann.state_dict())
except RuntimeError as e:
    # load_state_dict reports missing/unexpected keys via RuntimeError.
    # Catching only that (not BaseException) keeps Ctrl-C / SystemExit working.
    print('net_origin can not load! The error message is', e)
else:
    print('Load success!')
我们可以看出，1（ann）和 3（net_origin）的键名是匹配的，而 2（net_container）的键名与这两者都不同，所以只有 net_origin 能成功加载 ann 的参数，net_container 则加载失败。
输出为:
1 2 3 4 5 6 7 8 9 10
ann.state_dict.keys()=odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked']) net_container.state_dict.keys()=odict_keys(['0.0.weight', '0.1.weight', '0.1.bias', '0.1.running_mean', '0.1.running_var', '0.1.num_batches_tracked']) net_origin.state_dict.keys()=odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked']) net_container is trying to load state dict from ann... net_container can not load! The error message is Error(s) in loading state_dict for Sequential: Missing key(s) in state_dict: "0.0.weight", "0.1.weight", "0.1.bias", "0.1.running_mean", "0.1.running_var". Unexpected key(s) in state_dict: "0.weight", "1.weight", "1.bias", "1.running_mean", "1.running_var", "1.num_batches_tracked". net_origin is trying to load state dict from ann... Load success!
import torch from spikingjelly.activation_based import neuron, layer, functional
with torch.no_grad(): T = 4 N = 2 C = 4 H = 8 W = 8 x_seq = torch.rand([T, N, C, H, W]) net = layer.StepModeContainer( True, neuron.IFNode() ) net.step_mode = 'm' y_seq = net(x_seq) # y_seq.shape = [T, N, C, H, W] functional.reset_net(net)