当前位置: 首页 > article >正文

Pytorch是如何做显存管理的

  参考资料:

  GPT的回答

  自己的实验结果

  之前自己在用Pytorch跑模型训练的时候产生了如下一系列问题:1)Pytorch使用的cuda显存什么时候释放 2)什么时候会导致显存堆积 3)如何监控显存的使用。经过查找资料后找到了这些问题的答案,现在记录在此:

  1. Pytorch使用的cuda显存什么时候释放?

  我给出的答案很坑,即一般是不会主动释放的。在跑训练代码的时候,显存占用的增长主要有四个地方:1)模型的加载 2)数据的加载 3)模型的前向 4)模型的反向传播。由于训练往往是一个循环,显存的占用往往在模型的反向传播之后,即调用各种形式的backwards()方法之后达到峰值。在到达这个峰值之后,如果你在nvidia-smi中观察,这个峰值却往往是不会任何回落的,这就是Pytorch为本(训练)应用程序所保存/预留的所有显存,此后不管这个应用程序内部怎么造,哪怕这个应用程序在运行的末尾使用del删掉所有的引用,只要程序不调用torch.cuda.empty_cache方法,cuda的显存是永远都不会被释放的。

  此时大家就知道了,Pytorch程序的显存是可以通过删引用,然后empty_cache手动释放的,就好像内存领域的gc.collect() ?我的建议是一般不要使用empty_cache,因为这个方法基本百害而无一利,只会拖慢运行的速度,如果你的显存实在不够用,可以尝试把模型在cpu和cuda上来回腾挪(这么做好像更标准?有一个术语叫cpu_offload?),而不是del之后每次empty_cache然后重新加载。

  2. 什么时候会导致显存堆积

  如果你发现你的训练过程发生了显存的堆积,则很可能是你的程序保留了过多不同cuda tensor的引用,因此一定要慎用 torch.clone(),在大多数时候,只要你不使用过多的torch.clone(),或者用列表保存一堆tensor,显存都是不会累计的,毕竟模型就那么大,模型前向的特征图就那么大,模型权重的梯度就那么大,这些东西占用的显存在进行训练的时候都是一次性占用提升的,根本训练循环的时候也不会发生任何累计的。

  如果你想模拟显存堆积,一个最简单的场景就是for i in range(10000): A.append(torch.tensor(1).to("cuda")),即用一个累加列表不断保存cuda tensor。读者可试一试,先del然后nvidia-smi看显存有没有释放(给其他程序),然后empty_cache看显存有没有释放。

  3. 如何监控显存的使用  

  不是很想研究这个东西,读者可以参见Pytorch的文档:https://pytorch.org/docs/stable/cuda.html#memory-management

  我猜其中最常用的应该是memory_summary这个方法。而其实英伟达官方主推的监控和优化工具应该是nvidia nsight,这个东西应该是需要图形化页面支持的,没研究过,感兴趣的读者可以研究一下。


http://www.kler.cn/news/306947.html

相关文章:

  • qmt量化交易策略小白学习笔记第64期【qmt编程之获取获取期权全推数据--code_list全推tick数据】
  • 鸿蒙媒体开发系列01——资源分类访问
  • 移情别恋c++ ദ്ദി˶ー̀֊ー́ ) ——13.mapset
  • 【springboot】整合spring security 和 JWT
  • Vue接入高德地图并实现基本的路线规划功能
  • Redis基础,常用操作命令,主从复制,一主两从,事务数据库操作
  • day01 - Java基础语法
  • [Golang] Sync
  • HarmonyOS开发之全局状态管理
  • 天融信把桌面explorer.exe删了,导致开机之后无windows桌面,只能看到鼠标解决方法
  • C++基础面试题 | 什么是C++中的虚继承?
  • LabVIEW机动车动态制动性能校准系统
  • spring项目中如何通过redis的setnx实现互斥锁解决缓存缓存击穿问题
  • [项目][WebServer][HttpServer]详细讲解
  • 一码空传临时网盘PHP源码,支持提取码功能
  • 数据中台进化为数据飞轮的必要
  • 【笔记】自动驾驶预测与决策规划_Part2_基于模型的预测方法
  • 初学Linux(学习笔记)
  • Vue.js入门系列(二十九):深入理解编程式路由导航、路由组件缓存与路由守卫
  • 【C++】入门基础(下)
  • Java项目基于docker 部署配置
  • 关于新版本 tidb dashboard API 调用说明
  • 评价类——熵权法(Entropy Weight Method, EWM),完全客观评价
  • ansible安全优化篇
  • 在深圳停车场我居然能看到很漂亮的瓦房
  • 707. 设计链表
  • SQL,从每组中的 json 字段中提取唯一值
  • 鸿蒙开发基础
  • Rust Web开发框架对比:Warp与Actix-web
  • SpringBoot + MySQL + MyBatis 实操示例教学