1.继承是什么?
如下定义一个动物类Animal为基类,它基本两个实例属性name和age、一个方法call。
1 | class Animal(object): # python3中所有类都可以继承于object基类 |
解释:
- training (bool) - 指示模块当前是训练还是评估模式
- add_module() - 添加子模块
- apply() - 递归地将函数应用于所有子模块
- buffers() - 返回模块 buffer 的迭代器
- children() - 返回直接子模块的迭代器
- cpu()/cuda()/etc. - 将模块移动到相应设备
- double()/float()/etc. - 将模块参数和 buffer 转换为相应数据类型
- eval() - 将模块设为评估模式
- forward() - 定义前向传播计算,所有子类需要重写
- register_buffer() - 向模块添加 buffer
- register_parameter() - 向模块添加参数
- state_dict() - 返回模块状态的字典表示
- load_state_dict() - 从字典中加载模块状态
- parameters()/named_parameters() - 返回可训练参数的迭代器
- modules()/named_modules() - 返回所有子模块的迭代器
- zero_grad() - 将所有参数的梯度设为0
- train()/eval() - 设置模块为训练/评估模式
3.注意技巧
我们一般定义自己的网络的时候,会继承这个nn.Moudle,并重新构造__init__和forward这两个def,但有一些技巧需要注意:
- 将具有可学习参数的层放在构造函数
__init__中 - foward方法必须重写,实现各个层连接
1 |
|