[pytorch] 训练节省显存的技巧
参考资料:
- https://blog.csdn.net/Wenyuanbo/article/details/119107466
实践如下:
1. AdamW 优化器比 SGD 优化器更耗显存
- 有时候能够降几GB,但是有时候不怎么降
2. 删除多余无用变量
del 功能是彻底删除一个变量,要再使用必须重新创建,注意 del 删除的是一个变量而不是从内存中删除一个数据,这个数据有可能也被别的变量在引用,实现方法很简单,比如:
- 残差链接是可以删除的嘛?这个我没有实践过,我怕删了计算不了梯度
def forward(self, x):
input_ = x
x = F.relu_(self.conv1(x) + input_)
x = F.relu_(self.conv2(x) + input_)
x = F.relu_(self.conv3(x) + input_)
del input_ # 删除变量 input_
x = self.conv4(x) # 输出层
return x
- 训练完之后删除loss释放计算图,不然loss和下一个batch重叠在一起,很消耗显存,实测可以下降9GB左右显存,可能是我的模型有点大。
total_loss /= self.acc_step
total_loss.backward()
if i % self.acc_step == 0 or i == len(train_loader): # i starts from 1
if self.args.use_amp:
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
else:
self.optimizer.step()
self.optimizer.zero_grad()
del x, y, out, total_loss, losses # batch 可以删掉,同时 loss也可以删掉
3. 周期性地清空缓存
使用 torch.cuda.empty_cache()
释放显存。
这个我也没有试验过。
4. 清空计算图
loss.backward
是计算梯度, optimzer.step
是更新梯度,但是计算图没有释放,可以使用zero_grad释放计算图。
myNet.zero_grad() # 模型参数梯度清零
optimizer.zero_grad() # 优化器参数梯度清零
5. 半精度学习 / 混合精度学习
因为偷懒,使用的是 pytorch 自带的 scaler,发现其实没什么用
if self.args.use_amp:
self.scaler = GradScaler()
if self.args.use_amp:
with autocast():
out = self.model(x)
total_loss, losses = self.model.loss(out, y)
total_loss = self.scaler.scale(total_loss)
else:
out = self.model(x)
total_loss, losses = self.model.loss(out, y)
total_loss /= self.acc_step
total_loss.backward()
if i % self.acc_step == 0 or i == len(train_loader): # i starts from 1
if self.args.use_amp:
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
else:
self.optimizer.step()
self.optimizer.zero_grad()