tensorflow 显存 训练_【他山之石】训练时显存优化技术——OP合并与gradient checkpoint...
地址:http://bindog.github.io/
01
背景
前几天看到知乎上的文章FLOPs与模型推理速度[1],文中提到一个比较耗时又占显存的pointwise操作x * sigmoid(x),这实际上是swish activation[2];暂且不提它背后的争议,本文主要想从这个结构入手来优化它的显存占用以及耗时,并讨论更广泛的训练时显存优化技术。02
反向传播是如何工作的?
要分析清楚swish activation为什么会比较占显存,我们首先需要搞清楚反向传播是如何工作的,或者更进一步说,现有的自动求导框架是如何求出梯度的。先明确一点,所谓自动求导框架实际上是“半自动”的:它并非直接求出一个复杂函数导数的解析形式,而是通过构建计算图和预先写好的基础函数的求导规则,结合链式求导法则实现的自动求导。以swish acivation为例进行说明,其表达式为f(x) = x * sigmoid(x),通过简单的数学推导得到其梯度的解析式为f'(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x));先把这个结果放一边,看看自动求导框架是如何一步步求出这个结果的,画出计算图如下:除了计算图以外,我们还需要定义几个基本函数的求导规则,在这个例子里涉及两个函数,一个是乘法,另一个是sigmoid函数(实际上sigmoid也是由几个基本函数构成的,这里我们将其视为一个整体)f(x, y) = x * y# gradient for x: y# gradient for y: xg(x) = sigmoid(x) # 1 / (1 + exp(-x))# gradient for x: sigmoid(x) * (1 - sigmoid(x))03
显存被谁吃掉了
先说一个结论,在绝大多数神经网络的训练过程中,显存占用的大头是中间结果,也就是所谓的“特征图”。那我们为什么要保留中间结果呢?当然是为了方便求导啊!还是以swish acivation为例,把它放入神经网络来看,x就是前一层输出的中间结果(特征图)- 在适用乘法的求导规则时,要求我们要事先保留下中间结果x和sigmoid(x),有人可能会说只保留一个x不就可以了吗?sigmoid(x)可以通过计算得出,注意框架定义的乘法及其求导规则是通用规则,乘法的左右两边完全可能是不相关的两个值,所以必须同时保留下来。
- 在对sigmoid函数适用求导规则时,需要存下中间结果x。
04
手动合并OP
那么有没有办法优化呢?当然是可以的,既然我们能用数学公式提前算出swish acivation的梯度,那么直接将其视为一个整体不就好了?无非就是定义一个新的函数和新的求导规则
swish(x) = x * sigmoid(x)# gradient for x: sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))这样一来,计算图变成了下面这个样子:
x的梯度可以直接根据新的规则求出,而在新的规则下,我们只需要保留x这一个中间结果即可,sigmoid(x)可以根据x求出。所以说,自动求导框架虽然省事,但是其缺陷也很明显,由于大部分求导规则是面向通用的函数,很难针对特定的场景进行自动优化而导致显存浪费。对swish acivation这样的函数,只能依靠工程师的经验手动的进行优化。需要指出的是,现有的一些框架如TVM和TensorRT也能自动的对某些算子进行融合,进而大大提高计算效率,降低显存消耗,但是这些都属于部署阶段了,而本文讨论的均为训练阶段。类似的优化案例还有inplace-abn[3],针对的是类似BN-ReLU-Conv这样的常见结构组合,如下所示图中的虚线框是需要保留的中间结果,inplace-abn的优化思路是只保留中间结果z,通过反推得到x,然而众所周知ReLU是不可逆的运算,因此inplace-abn将其替换为了Leaky ReLU,计算图变成了如下形式:接下来的事情就是用数学的方式手动求出导数,然后定义成规则即可。对II型,更进一步,直接用的反函数,进行替换即可
虽然推导过程有些复杂,但写出求导公式后,我们只需要将其封装进手写的模块中即可。原论文[4]中的实现表明,采用Inplace-abn后,显存占用最高可下降50%左右,而且由于Leaky ReLU实际效果其实与ReLU非常接近,省下来的显存可以用于提高batch_size,模型训练实际上能从中得到更大收益。
05
还能更进一步吗?
回想前面的优化过程,我们发现其实这是一种典型的时间换空间的做法,虽然模型占用的显存下降了(舍弃了大量中间结果),但是我们定义的求导规则非常复杂,计算步骤明显多于优化前,其根本原因并非是不需要中间结果,而是有办法在求导过程中实时的计算出之前被舍弃掉的中间结果。考虑GPU上显存资源与计算资源的关系,只用较少的计算量和额外的一点计算时间换取宝贵的显存资源,这么做实际上是划算的。如果沿着这个思路更进一步,所有的中间结果都不需要存储了,只需要存最初的输入即可,因为所有的中间结果都可以由输入重新计算得到,然而这个方案显然是不划算的,因为反向传播的过程是“由深入浅”,而计算中间结果的过程是“由浅入深”,二者的方向并不匹配,每当我们需要中间结果时就需要从头再来一遍,这样的计算和时间开销显然是不划算的。如果折中一下呢?这就是OpenAI提出的gradient-checkpoint的思路,在神经网络中间设置若干个检查点(checkpoint),检查点以外的中间结果全部舍弃,反向传播求导数的时间,需要某个中间结果时,从最近的检查点开始计算,这样既节省了显存,又避免了从头计算的繁琐过程;从代码层面来看,原版实现[5]用的是tensorflow,由于是静态图的缘故,需要用到grapheditor等一系列骚操作,而且包含了很多“智能”寻找bottleneck选择为checkpoint的代码,很容易劝退新人。但是如果看一下pytorch的官方实现[6],你会惊讶的发现gradient-checkpoint的核心部分出奇的简单,这也算是动态图以及pytorch的一点小优势吧,当然pytorch版本的实现并不包括智能寻找checkpoint点的功能,需要人为设定。核心代码如下所示:class CheckpointFunction(torch.autograd.Function): @staticmethod def forward(ctx, run_function, preserve_rng_state, *args): check_backward_validity(args) ctx.run_function = run_function ctx.preserve_rng_state = preserve_rng_state if preserve_rng_state: ctx.fwd_cpu_state = torch.get_rng_state() # Don't eagerly initialize the cuda context by accident. # (If the user intends that the context is initialized later, within their # run_function, we SHOULD actually stash the cuda state here. Unfortunately, # we have no way to anticipate this will happen before we run the function.) ctx.had_cuda_in_fwd = False if torch.cuda._initialized: ctx.had_cuda_in_fwd = True ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args) ctx.save_for_backward(*args) with torch.no_grad(): outputs = run_function(*args) return outputs @staticmethod def backward(ctx, *args): if not torch.autograd._is_checkpoint_valid(): raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible") inputs = ctx.saved_tensors # Stash the surrounding rng state, and mimic the state that was # present at this time during forward. Restore the surrounding state # when we're done. rng_devices = [] if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: rng_devices = ctx.fwd_gpu_devices with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): if ctx.preserve_rng_state: torch.set_rng_state(ctx.fwd_cpu_state) if ctx.had_cuda_in_fwd: set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) detached_inputs = detach_variable(inputs) with torch.enable_grad(): outputs = ctx.run_function(*detached_inputs) if isinstance(outputs, torch.Tensor): outputs = (outputs,) torch.autograd.backward(outputs, args) grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs) return (None, None) + grads注意到最近旷视开源的MegEngine,在PR的时候提到一个亚线性显存优化技术[7],其实就是gradient-checkpoint技术,详情可参考论文Training Deep Nets with Sublinear Memory Cost[8],当然MegEngine肯定在细节上对其进行了一些优化,本文就不展开讨论了。06
CUDA版的swish activation
回到swish activation的优化上来,如果要追求效率的极致提升,下一步考虑的方案应该是手写C++ extension,将计算从python层面转移到C++与CUDA上。如何基于 pytorch写C++扩展,官方文档上有非常详细的教程[9],写法和方式也都比较灵活,可以根据自己的习惯进行选择,这里我们选择利用setuptools的方式进行构建pytorch在用户自定义扩展上也是做了非常多的支持,用户能非常方便的使用pytorch底层定义好的一些类和函数;在写CUDA函数时,pytorch还提供了一个CUDAApplyUtils.cuh头文件,专门用于优化pointwise操作的情况,以减小拷贝和临时存储的显存浪费(用于lambda函数,函数名非常直观,CUDA_tensor_applyN表示操作数的个数,N可以为1,2,3,4,用户还可以指定每个操作数的属性,如只读/读写,针对每对情形都有专门的优化实现)对于swish activation来说,由于全是pointwise操作,利用这个优化技巧可以把显存占用进一步压缩。具体代码可参考swish_optimize[10]简单对比一下以上几种实现在实际场景中(单卡RTX 2070,resnet50, bs=32)的显存占用情况和运行时间(一次forward & 一次backward & 参数更新)- 无优化纯Python:GPU memory=6383MB,time=223ms
- 合并算子(Python):GPU memory=5139MB,time=234ms
- 合并算子(CUDA):GPU memory=5143MB,time=188ms
[1] https://zhuanlan.zhihu.com/p/122943688
[2] https://arxiv.org/abs/1710.05941
[3] https://github.com/mapillary/inplace_abn
[4] https://arxiv.org/pdf/1712.02616.pdf
[5] https://github.com/cybertronai/gradient-checkpointing
[6] https://github.com/pytorch/pytorch/blob/176174a68ba2d36b9a5aaef0943421682ecc66d4/torch/utils/checkpoint.py#L55
[7] https://zhuanlan.zhihu.com/p/138730559
[8] https://arxiv.org/abs/1604.06174
[9] https://pytorch.org/tutorials/advanced/cpp_extension.html
[10] https://github.com/bindog/swish_optimize
本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。
直播预告
历史文章推荐
【CVPR 2020 Tutorial】如何写好论文和评审(概述)
如何撰写高水平的博士论文?超全论文指导!
北大读博手记:怎样完成自己的博士生涯?非常具有指导性!
太牛逼了!一位中国博士把整个CNN都给可视化了,每个细节看的清清楚楚!
Nature发表牛津博士建议:我希望在读博士之初时就能知道的20件事
沈向洋、华刚:读科研论文的三个层次、四个阶段与十个问题
如何看待2021年秋招算法岗灰飞烟灭?
独家解读 | ExprGAN:基于强度可控的表情编辑
独家解读 | 矩阵视角下的BP算法
独家解读 | Capsule Network深度解读
独家解读 | Fisher信息度量下的对抗攻击
论文解读 | 知识图谱最新研究综述
你的毕业论文过了吗?《如何撰写毕业论文?》
卡尔曼滤波系列——经典卡尔曼滤波推导
分享、点赞、在看,给个三连击呗!
总结
以上是生活随笔为你收集整理的tensorflow 显存 训练_【他山之石】训练时显存优化技术——OP合并与gradient checkpoint...的全部内容,希望文章能够帮你解决所遇到的问题。
- 上一篇: python收入波动告警分析_使用Pyt
- 下一篇: 标题隐藏_头条官方课程没看就想起好标题?