当前位置: 首页 > article >正文

[pytorch] 训练节省显存的技巧

参考资料:

  • https://blog.csdn.net/Wenyuanbo/article/details/119107466

实践如下:

1. AdamW 优化器比 SGD 优化器更耗显存

  • 有时候能够降几GB,但是有时候不怎么降

2. 删除多余无用变量

del 功能是彻底删除一个变量,要再使用必须重新创建,注意 del 删除的是一个变量而不是从内存中删除一个数据,这个数据有可能也被别的变量在引用,实现方法很简单,比如:

  • 残差链接是可以删除的嘛?这个我没有实践过,我怕删了计算不了梯度
 def forward(self, x):
 
	input_ = x
	x = F.relu_(self.conv1(x) + input_)
	x = F.relu_(self.conv2(x) + input_)
	x = F.relu_(self.conv3(x) + input_)
	
	del input_  # 删除变量 input_

	x = self.conv4(x)  # 输出层
	return x

  • 训练完之后删除loss释放计算图,不然loss和下一个batch重叠在一起,很消耗显存,实测可以下降9GB左右显存,可能是我的模型有点大。
total_loss /= self.acc_step
total_loss.backward()
if i % self.acc_step == 0 or i == len(train_loader):  # i starts from 1
    if self.args.use_amp:
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()
    else:
        self.optimizer.step()
        self.optimizer.zero_grad()

del x, y, out, total_loss, losses   #  batch 可以删掉,同时 loss也可以删掉

3. 周期性地清空缓存

使用 torch.cuda.empty_cache() 释放显存。

这个我也没有试验过。

4. 清空计算图

loss.backward 是计算梯度, optimzer.step 是更新梯度,但是计算图没有释放,可以使用zero_grad释放计算图。

myNet.zero_grad()  # 模型参数梯度清零
optimizer.zero_grad()  # 优化器参数梯度清零

5. 半精度学习 / 混合精度学习

因为偷懒,使用的是 pytorch 自带的 scaler,发现其实没什么用

if self.args.use_amp:
            self.scaler = GradScaler()

if self.args.use_amp:
   with autocast():
         out = self.model(x)
         total_loss, losses = self.model.loss(out, y)
         total_loss = self.scaler.scale(total_loss)
 else:
     out = self.model(x)
     total_loss, losses = self.model.loss(out, y) 

total_loss /= self.acc_step
 total_loss.backward()
 if i % self.acc_step == 0 or i == len(train_loader):  # i starts from 1
     if self.args.use_amp:
         self.scaler.step(self.optimizer)
         self.scaler.update()
         self.optimizer.zero_grad()
     else:
         self.optimizer.step()
         self.optimizer.zero_grad()

http://www.kler.cn/news/304439.html

相关文章:

  • Kizuna AI——AI驱动虚拟偶像,AI分析观众的反应和互动,应用娱乐、直播和广告行业
  • Linux(RedHat或CentOS)下如何开启telnet服务
  • 【时时三省】(C语言基础)指针进阶 例题7
  • SQLITE3数据库实现信息的增删改查
  • ensp—路由过滤、路由引入、路由策略
  • 【基础知识复习 - 随机练习题】
  • 1935. 公交换乘(transfer)
  • 常用环境部署(二十)——docker部署OpenProject
  • 基于华为云服务器的网页部署
  • 【Android】使用和风天气API获取天气数据吧!(天气预报系列之一)
  • ARCGIS PRO DSK MapTool
  • 使用Azure Devops Pipeline将Docker应用部署到你的Raspberry Pi上
  • 【Hadoop|MapReduce篇】Hadoop序列化概述
  • LabVIEW FIFO详解
  • 分享六款小众宝藏软件,建议收藏!
  • golang os.Eixt的介绍和使用
  • 【C++】vector常见用法
  • 数字化大屏解决方案 - GoView
  • 如何通俗易懂的解释TON的智能合约
  • DolphinScheduler应用实战笔记
  • ROS2 Control controller_interface说明
  • 论文阅读笔记: DINOv2: Learning Robust Visual Features without Supervision
  • LOAM学习
  • camouflaged object detection中的decoder最核心的作用
  • Amazon EC2:灵活、可扩展的云计算解决方案
  • Flutter iOS混淆打包
  • 安卓13禁止声音调节对话框 删除音量调节对话框弹出 屏蔽音量对话框 android13
  • springcloud OpenFeign 日志打印功能
  • java项目之中药实验管理系统(源码+文档)
  • Linux 入门:简单的基础操作