欢迎访问 生活随笔!

生活随笔

当前位置: 首页 > 编程资源 > 编程问答 >内容正文

编程问答

pytorch笔记 pytorch模型中的parameter与buffer

发布时间:2025/4/5 编程问答 51 豆豆
生活随笔 收集整理的这篇文章主要介绍了 pytorch笔记 pytorch模型中的parameter与buffer 小编觉得挺不错的,现在分享给大家,帮大家做个参考.

1 模型的两种参数

在 Pytorch 中一种模型保存和加载的方式如下:(具体见pytorch模型的保存与加载_刘文巾的博客-CSDN博客)

#save torch.save(net.state_dict(),PATH)#load model=MyModel(*args,**kwargs) model.load_state_dict(torch.load(PATH)) model.eval

模型保存的是 net.state_dict() 的返回对象。

net.state_dict() 的返回对象是一个 OrderDict ,它以键值对的形式包含模型中需要保存下来的参数

上例模型中的参数就是线性层的 weight 和 bias.

 

模型中需要保存下来的参数包括两种:

  • 一种是反向传播需要被optimizer更新的,称之为 parameter
  • 一种是反向传播不需要被optimizer更新,称之为 buffer

第一种参数我们可以通过 model.parameters() 返回;

第二种参数我们可以通过 model.buffers() 返回。

因为我们的模型保存的是 state_dict 返回的 OrderDict,所以这两种参数不仅要满足是否需要被更新的要求,还需要被保存到OrderDict

2 Parameter

 

Parameter参数有两种创建方式:

  • 我们可以直接将模型的成员变量self.xxx通过nn.Parameter() 创建,会自动注册到parameters中,可以通过model.parameters() 返回,并且这样创建的参数会自动保存到OrderDict中去;
  • 通过nn.Parameter() 创建普通Parameter对象,不作为模型的成员变量,然后将Parameter对象通过register_parameter()进行注册,可以通model.parameters() 返回,注册后的参数也会自动保存到OrderDict中去;
  • 像我们前面的nn.Conv1d,nn.Linear,nn.RNN等模型,里面的权重参数等会被自动认为是Parameter 参数

    3 buffer

    buffer参数我们需要创建tensor, 然后将tensor通过register_buffer()进行注册,可以通model.buffers() 返回,注册完后参数也会自动保存到OrderDict中去。

    4 为什么要注册​​​​​​​

    为什么不直接将不需要进行参数修改的变量作为模型类的成员变量就好了,还要进行注册?

  • 不进行注册,参数不能保存到 OrderDict,也就无法进行保存
  • 模型进行参数在CPU和GPU移动时, 执行 model.to(device) ,注册后的参数可以自动进行设备移动
  • 5 实例说明

    import torch class net(torch.nn.Module):def __init__(self):super(net,self).__init__()#创建bufferself.register_buffer('my_buffer',torch.Tensor([1,2,3]))self.a=torch.Tensor([1])self.param1=torch.nn.Parameter(torch.Tensor([1,3,5,7,9]))#方法1 创建的parameterparam2=torch.nn.Parameter(torch.Tensor([2,4,6,8,0]))self.register_parameter('param2',param2)self.l=torch.nn.Linear(1,10)def forward(self,x):passn=net()for i in n.state_dict():print(i,n.state_dict()[i]) print('*'*10) for i in n.parameters():print(i) print('*'*10) for i in n.buffers():print(i) print('*'*10)''' param1 tensor([1., 3., 5., 7., 9.]) param2 tensor([2., 4., 6., 8., 0.]) my_buffer tensor([1., 2., 3.]) l.weight tensor([[-0.1490],[-0.2445],[-0.5296],[-0.3687],[-0.9683],[ 0.3491],[-0.8726],[-0.7213],[ 0.3201],[-0.9994]]) l.bias tensor([ 0.6718, 0.3055, 0.7755, 0.3780, -0.8169, 0.3663, -0.6937, -0.3136,0.6907, 0.8732]) ********** Parameter containing: tensor([1., 3., 5., 7., 9.], requires_grad=True) Parameter containing: tensor([2., 4., 6., 8., 0.], requires_grad=True) Parameter containing: tensor([[-0.1490],[-0.2445],[-0.5296],[-0.3687],[-0.9683],[ 0.3491],[-0.8726],[-0.7213],[ 0.3201],[-0.9994]], requires_grad=True) Parameter containing: tensor([ 0.6718, 0.3055, 0.7755, 0.3780, -0.8169, 0.3663, -0.6937, -0.3136,0.6907, 0.8732], requires_grad=True) ********** tensor([1., 2., 3.]) ********** '''

     

    总结

    以上是生活随笔为你收集整理的pytorch笔记 pytorch模型中的parameter与buffer的全部内容,希望文章能够帮你解决所遇到的问题。

    如果觉得生活随笔网站内容还不错,欢迎将生活随笔推荐给好友。