pytorch节省显存小技巧

编程入门 行业动态 更新时间:2024-10-08 08:27:53

pytorch节省<a href=https://www.elefans.com/category/jswz/34/1749167.html style=显存小技巧"/>

pytorch节省显存小技巧

使用pytorch进行文本多分类问题时遇到了显存out of memory的情况,实验了多种方法,主要比较有效的有两种:
1、尽可能使用inplace操作,比如relu可以使用inplace=True
进一步将BN归一化和激活层Relu打包成inplace,在BP的时候再重新计算。
代码与论文参考
mapillary/inplace_abn
efficient_densenet_pytorch
效果:可以减少一半的显存
原理:
在大多数的深度网络的前向传播中,都有BN-Activation-Conv这样的网络结构,就必须要存储归一化的输入和全卷机层的输入。这是有必要的,因为反向传播需要输入计算梯度。通过重写BN的反向传播步骤,使用ABN代替BN-Activation序列,可以不用存储BN的输入(可通过激活函数的输出即全卷机层的输入反推),节省50%的显存。
2、使用float16精度混合计算,利用NVIDIA 的apex,也能减少50%的显存。但是有一些操作不安全如mean、sum等
官方代码参考
NVIDIA apex
3、pytorch1.0提供了模型拆分成2部分在2张卡上运行的方案
pytorch官网多卡例子
4、使用pytorch1.0的checkpoint特性,可以减少90%的显存
ckeckpoint通过交换计算内存来工作。而不是存储整个计算图的所有中间激活用于向后计算。ckeckpoint不会保存中间的激活参数,而是通过反向传播时重新计算他们。
具体可见我的GitHub一个文本分类项目
我的GitHub

更多推荐

pytorch节省显存小技巧

本文发布于:2024-02-06 12:28:10,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1748991.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:显存   小技巧   节省   pytorch

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!