欢迎访问 生活随笔!

生活随笔

当前位置: 首页 > 编程资源 > 编程问答 >内容正文

编程问答

pytorch maxout实现

发布时间:2025/4/16 编程问答 53 豆豆
生活随笔 收集整理的这篇文章主要介绍了 pytorch maxout实现 小编觉得挺不错的,现在分享给大家,帮大家做个参考.

简述

看了半天,在网上没有看到pytorch关于maxout的实现。(虽然看到其他的模型也是可以用的,但是为了更好的复现论文,这里还是打算实现下)。

(不一定保证完全正确,估计很快pytorch就会自己更新,对应的maxout激活函数了吧?我看到github上好像有对应的issue了都)

maxout的原理也很简单:简单来说,就是多个线性函数的组合。然后在每个定义域上都取数值最大的那个线性函数,看起来就是折很多次的折线。(初中数学emmm)

实现

from torch.nn import init import torch.nn.functional as F from torch._jit_internal import weak_module, weak_script_method from torch.nn.parameter import Parameter import math@weak_module class Maxout(nn.Module):__constants__ = ['bias']def __init__(self, in_features, out_features, pieces, bias=True):super(Maxout, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.pieces = piecesself.weight = Parameter(torch.Tensor(pieces, out_features, in_features))if bias:self.bias = Parameter(torch.Tensor(pieces, out_features))else:self.register_parameter('bias', None)self.reset_parameters()def reset_parameters(self):init.kaiming_uniform_(self.weight, a=math.sqrt(5))if self.bias is not None:fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)bound = 1 / math.sqrt(fan_in)init.uniform_(self.bias, -bound, bound)@weak_script_methoddef forward(self, input):output = input.matmul(self.weight.permute(0, 2, 1)).permute((1, 0, 2)) + self.biasoutput = torch.max(output, dim=1)[0]return output

如果也喜欢研究源码的小伙伴就会发现了,我就是在原来的Linear()的源码基础上多改进了一个维度而已。

技巧还是在那个维度切换那里,其他都没啥,用自己这个试了下,效果还行(不亏是我,叉腰.jpg)

调用的方式也很简单,就是平常写的那些nn.Linear() 的方式很像。

就是跟nn.Linear一样的用啊。pieces的概念,就是pieces个函数,在定义上每个点上,取最大的那个函数对应的数值,作为整个函数的最大值。

nn.Sequential(Maxout(in_c, out_c, pieces) )

总结

以上是生活随笔为你收集整理的pytorch maxout实现的全部内容,希望文章能够帮你解决所遇到的问题。

如果觉得生活随笔网站内容还不错,欢迎将生活随笔推荐给好友。