当前位置: 首页 > article >正文

【CNN训练梯度裁剪】

  • 梯度裁剪
    • 参考

      • 梯度剪裁
      • PyTorch使用Tricks:梯度裁剪-防止梯度爆炸
      • 【Pytorch】梯度裁剪——torch.nn.utils.clip_grad_norm_的原理及计算过程
    • 方法

      梯度裁剪(Gradient Clipping)是一种防止梯度爆炸或梯度消失的优化技术,它可以在反向传播过程中对梯度进行缩放或截断,使其保持在一个合理的范围内。梯度裁剪有两种常见的方法:

      1. 按照梯度的绝对值进行裁剪,即如果梯度的绝对值超过了一个阈值,就将其设置为该阈值的符号乘以该阈值。
      2. 按照梯度的范数进行裁剪,即如果梯度的范数超过了一个阈值,就将其按比例缩小,使其范数等于该阈值。例如,如果阈值为1,那么梯度的范数就是1。
    • pytorch实现

      import torch.nn as nn
       
      outputs = model(data)
      loss= loss_fn(outputs, target)
      optimizer.zero_grad()
      loss.backward()
      # 方法一
      torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
      # 方法二
      nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)	#! 在梯度更新前裁剪
      optimizer.step()
      
      # 获取梯度的范数
      # 对于模型的每个参数,计算其梯度的L2范数
      for param in model.parameters():
          grad_norm = torch.norm(param.grad, p=2)
          print(grad_norm)
      

http://www.kler.cn/a/292029.html

相关文章:

  • vue+svg圆形进度条组件
  • Spring Boot 启动时修改上下文
  • ISP是什么?
  • 狼蛛F87Pro键盘常用快捷键的使用说明
  • 【链路层】空口数据包详解(4):数据物理通道协议数据单元(PDU)
  • 从电动汽车到车载充电器:LM317LBDR2G 线性稳压器在汽车中的多场景应用
  • HarmonyOS $r访问资源
  • MyPrint打印设计器(九)svg篇-圆
  • 【计算机视觉前沿研究 热点 顶会】ECCV 2024中Mamba有关的论文
  • C# NX二次开发-获取体全部面
  • Circuitjs 在线电路模拟器使用指南
  • tomcat日志显示中文乱码的方法解决
  • MySQL基础:索引
  • ESRI ArcGIS Pro 3.1.5新功能及安装教程和下载
  • python常用库学习-Matplotlib使用
  • Redis——BigKey
  • 【MySQL】主键优化原理篇——【数据组织方式&主键顺序插入&主键乱序插入&页分裂&页分裂】
  • 【Python机器学习】核心数、进程、线程、超线程、L1、L2、L3级缓存
  • 空气质量题数据处理与分析
  • 在元神操作系统中获取动态内存
  • HarmonyOS开发实战( Beta5版)AOT编译使用指南
  • 电脑知识:如何恢复 Word、媒体和存档文件?
  • 在线演示文稿应用PPTist本地化部署并实现无公网IP远程编辑PPT
  • 客户端时间和服务器时间的区别
  • GO环境安装
  • AIoTedge IoT平台替代网关、PLC和HMI,实现智慧农业大棚控制