pytorch笔记 pytorch模型中的parameter与buffer
生活随笔
收集整理的这篇文章主要介绍了
pytorch笔记 pytorch模型中的parameter与buffer
小编觉得挺不错的,现在分享给大家,帮大家做个参考.
1 模型的两种参数
在 Pytorch 中一种模型保存和加载的方式如下:(具体见pytorch模型的保存与加载_刘文巾的博客-CSDN博客)
#save torch.save(net.state_dict(),PATH)#load model=MyModel(*args,**kwargs) model.load_state_dict(torch.load(PATH)) model.eval模型保存的是 net.state_dict() 的返回对象。
net.state_dict() 的返回对象是一个 OrderDict ,它以键值对的形式包含模型中需要保存下来的参数
上例模型中的参数就是线性层的 weight 和 bias.
模型中需要保存下来的参数包括两种:
- 一种是反向传播需要被optimizer更新的,称之为 parameter
- 一种是反向传播不需要被optimizer更新,称之为 buffer
第一种参数我们可以通过 model.parameters() 返回;
第二种参数我们可以通过 model.buffers() 返回。
因为我们的模型保存的是 state_dict 返回的 OrderDict,所以这两种参数不仅要满足是否需要被更新的要求,还需要被保存到OrderDict。
2 Parameter
Parameter参数有两种创建方式:
像我们前面的nn.Conv1d,nn.Linear,nn.RNN等模型,里面的权重参数等会被自动认为是Parameter 参数
3 buffer
buffer参数我们需要创建tensor, 然后将tensor通过register_buffer()进行注册,可以通过model.buffers() 返回,注册完后参数也会自动保存到OrderDict中去。
4 为什么要注册
为什么不直接将不需要进行参数修改的变量作为模型类的成员变量就好了,还要进行注册?
5 实例说明
import torch class net(torch.nn.Module):def __init__(self):super(net,self).__init__()#创建bufferself.register_buffer('my_buffer',torch.Tensor([1,2,3]))self.a=torch.Tensor([1])self.param1=torch.nn.Parameter(torch.Tensor([1,3,5,7,9]))#方法1 创建的parameterparam2=torch.nn.Parameter(torch.Tensor([2,4,6,8,0]))self.register_parameter('param2',param2)self.l=torch.nn.Linear(1,10)def forward(self,x):passn=net()for i in n.state_dict():print(i,n.state_dict()[i]) print('*'*10) for i in n.parameters():print(i) print('*'*10) for i in n.buffers():print(i) print('*'*10)''' param1 tensor([1., 3., 5., 7., 9.]) param2 tensor([2., 4., 6., 8., 0.]) my_buffer tensor([1., 2., 3.]) l.weight tensor([[-0.1490],[-0.2445],[-0.5296],[-0.3687],[-0.9683],[ 0.3491],[-0.8726],[-0.7213],[ 0.3201],[-0.9994]]) l.bias tensor([ 0.6718, 0.3055, 0.7755, 0.3780, -0.8169, 0.3663, -0.6937, -0.3136,0.6907, 0.8732]) ********** Parameter containing: tensor([1., 3., 5., 7., 9.], requires_grad=True) Parameter containing: tensor([2., 4., 6., 8., 0.], requires_grad=True) Parameter containing: tensor([[-0.1490],[-0.2445],[-0.5296],[-0.3687],[-0.9683],[ 0.3491],[-0.8726],[-0.7213],[ 0.3201],[-0.9994]], requires_grad=True) Parameter containing: tensor([ 0.6718, 0.3055, 0.7755, 0.3780, -0.8169, 0.3663, -0.6937, -0.3136,0.6907, 0.8732], requires_grad=True) ********** tensor([1., 2., 3.]) ********** '''
总结
以上是生活随笔为你收集整理的pytorch笔记 pytorch模型中的parameter与buffer的全部内容,希望文章能够帮你解决所遇到的问题。
- 上一篇: pytorch学习笔记 torchnn.
- 下一篇: 文巾解题 16. 最接近的三数之和