当前位置: 首页 > article >正文

Pytorch训练时报nan

0. 引言

Pytorch训练时在batch=N时loss为nan。经过断点检查发现在batch=N-1时,网络参数非nan,输出非nan,但梯度为nan,导致网络参数已经全部被更新为nan,遇到这种情况应该如何排查,如何避免?由于导致nan的情况较为繁多,本文给出的不是一个个例的解决方案,而是一种通用的抽象解决方案。

1. 排查

最简单的排查的方式就是检查parameter的参数值:

# model
for name, param in model.named_parameters(recurse=True):
	if not torch.isfinite(param.mean()):
		print(name)

通过该种方法可以打印出网络参数中数值非有限值的参数所在层。

第二种方法是检查parameter的梯度值,该方法需要retain_graph=True (Pytorch默认不保存图结构以节省GPU内存)

# compute loss
loss.backward(retain_graph=True)
# model
for name, param in model.named_parameters(recurse=True):
	if not torch.isfinite(param.grad.mean()):
		print(name)

检查梯度和参数值的方式都是从后往前查(和反向传播的顺序一致),子节点出现问题会导致其根节点必定出现问题,因此优先排查子节点是否是导致nan的原因。

最后提醒一下,如果nan排查成功,别忘了把retain_graph=True给删了,因为这条命令占用额外的GPU内存。

2. 规避

在这里介绍的方法是基于Pytorch 1.13的,Pytorch 2.x的用户也不想要担心,因为本教程中设置的参数在Pytorch 2.x里面已经设为默认参数,完全兼容。

# compute loss
# optimizer, model
clip_grad = 1.0 # maximum value to clip grad_norm
try:
	nn.utils.clip_grad_norm_(model.parameters(), clip_grad, norm_type=2, error_if_nonfinite=True) # 遇到nonfinite的梯度报错
	optimizer.step()
except:
	print("nan detected in grad, skip batch")
	optimizer.zero_grad()  # 所有梯度置0,保证下一个batch的正常训练
	continue  # 跳过这个batch的训练

这个代码的思想就是利用clip_grad_norm_自带的梯度检查功能在反向传播前对model的每个参数梯度进行检查,如若出现梯度异常值,则跳过batch(且不会对网络进行梯度更新)。需要的注意的是,optimizer.zero_grad()除了在本代码中出现,应该在主循环里面也另外有一个,但是此处省略了。


http://www.kler.cn/a/384771.html

相关文章:

  • 如何对接低价折扣相对稳定电影票渠道?
  • 块存储、文件存储和对象存储详细介绍
  • 核心数据类型转换
  • 汽车广告常见特效处理有哪些?
  • Linux(ubuntu) 部署xinference
  • C语言复习第7章 自定义类型(结构体+位段+枚举+联合体)
  • laravel chunkById 分块查询 使用时的问题
  • Spring Cloud Bus快速入门Demo
  • 第九周预习报告
  • qt QItemSelectionModel详解
  • 多个服务器共享同一个Redis Cluster集群,并且可以使用Redisson分布式锁
  • Git LFS
  • 专业130+总400+武汉理工大学855信号与系统考研经验电子信息与通信工程,真题,大纲,参考书。
  • 内置函数【MySQL】
  • 生产环境中使用:带有核函数的 SVM 处理非线性问题
  • Unity 的 WebGL 构建中资源图片访问方式
  • 人工智能:重塑生活与工作的神奇力量
  • WebRTC REMB算法
  • AIGC--如何在内容创作中合理使用AI生成工具?
  • H.265流媒体播放器EasyPlayer.js网页web无插件播放器:如何优化加载速度
  • 使用 Java 实现邮件发送功能
  • Matlab实现鲸鱼优化算法优化随机森林算法模型 (WOA-RF)(附源码)
  • 23isctf
  • tomcat 开启远程debug模式
  • vue组件获取props中的数据并绑定到form表单 el-form-item的v-model中方法
  • Django-------重写User模型