【CNN训练梯度裁剪】
-
梯度裁剪
-
参考
- 梯度剪裁
- PyTorch使用Tricks:梯度裁剪-防止梯度爆炸
- 【Pytorch】梯度裁剪——torch.nn.utils.clip_grad_norm_的原理及计算过程
-
方法
梯度裁剪(Gradient Clipping)是一种防止梯度爆炸或梯度消失的优化技术,它可以在反向传播过程中对梯度进行缩放或截断,使其保持在一个合理的范围内。梯度裁剪有两种常见的方法:
- 按照梯度的绝对值进行裁剪,即如果梯度的绝对值超过了一个阈值,就将其设置为该阈值的符号乘以该阈值。
- 按照梯度的范数进行裁剪,即如果梯度的范数超过了一个阈值,就将其按比例缩小,使其范数等于该阈值。例如,如果阈值为1,那么梯度的范数就是1。
-
pytorch实现
import torch.nn as nn outputs = model(data) loss= loss_fn(outputs, target) optimizer.zero_grad() loss.backward() # 方法一 torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5) # 方法二 nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2) #! 在梯度更新前裁剪 optimizer.step() # 获取梯度的范数 # 对于模型的每个参数,计算其梯度的L2范数 for param in model.parameters(): grad_norm = torch.norm(param.grad, p=2) print(grad_norm)
-