深度学习模型训练过程的加速方法介绍
一、 简介
在深度学习中,神经网络的训练过程往往是最耗费时间的。本文介绍一些实用的技巧,实现代码的提速,具体提速效果可由读者亲自尝试。
二、 方法
1 训练样本随机抽样
此方法对点云类的训练样本有效(比如三维形体数据)。举个例子,存在一个三维形体数据集,每个三维形体由10000个点来描述。那么在训练过此中我们可以在每次计算Loss函数时候,只随机选取10000个点中10%开展训练。这样不仅可以降低显存占用,还可以在显存一定的条件下增加batch_size,使训练过程稳定。不过,对于图像栅格数据,此类方法并不适用。
2 在不需要计算梯度的地方使用torch.no_grad()
对于测试集,不需要作反向传播优化神经网络参数,因此计算Loss函数的时候不需要记录梯度信息。在测试集代码中,可以使用torch.no_grad()关闭梯度计算
with torch.no_grad():
特别地,对于目标函数中用到自动求导的问题(比如PINN),可以禁用梯度计算create_graph=False
以提升速度并降低显存占用
with torch.no_grad(): # 禁用梯度计算以提升速度
grad_u_xyz = torch.autograd.grad(
outputs=u_xyz,
inputs=xyz_field,
grad_outputs=ones,
create_graph=False,
retain_graph=True,
allow_unused=False
)[0]
3 混合精度训练
混合精度训练可以在一些情况下自动将float32数值精度降低为float16以显著提高计算速度并降低显存占用。使用方法很简单,只需要在原先的代码基础上做小幅度修改即可,首先在epoch循环前初始GradScaler
scaler = GradScaler()
然后使autocast
包装向前传播过此,也就算计算Loss的过程
with autocast(device_type='cuda'):
loss = loss_function(dead, data)
然后反向传播和优化器梯度下降使用GradScaler
包装
# 反向传播:使用 scaler 缩放损失值
scaler.scale(loss).backward()
# 优化器更新:使用 scaler 缩放梯度
scaler.step(optimizer)
# 更新 scaler 的状态
scaler.update()
4 提高数据传输的速度
这一点需要尤其注意。有时候我们以为把代码放在GPU上计算可以提高计算速度,但是往往数据从CPU内存拷贝到GPU的时间远远大于计算所需时间,这样导致训练过此中大部分时间被用来加载数据,严重浪费算力。
在Pytorch中Dataset
用于存储数据集,并依托DataLoader
将数据集拆封为多个batch,“投喂”给模型,开展训练。考虑到计算机的储存主要由三个部件构成硬盘、内存和显存,数据传递流程如下
如果每个batch都要完成上述流程,那么无疑是极其浪费时间的。如果能将数据全部存储在内存上,每当调用Loss函数的时候,使用.cuda()
命令将其转移到显存上,那么将会节约一大部分时间。进一步地,如果能将数据直接全部存储在显存上,那么就可以避免数据传输的耗时,可以最大限度提速GPU并行计算的效率。
如果将数据全部存储在内存上,那么对内存的要求会比较高,拼一拼市面上内存还是可以满足要求的;但是如果将数据全部存储在显存上,那么对显存的要求会很高,市面上好的显卡也就80G显存,所以此方法还需根据具体实际条件选用。
5 减少损失函数中不必要的耗时计算
具体地,不要调用hasattr(model, ‘latent_vectors’)这种耗时的函数;
不要使用for循环,而应当并行化;对于如何将for循环并行化,完全可以请教deepseek或chatGPT, 让它们给出具体的调优方案
尽量少用.item()函数;
避免不必要的数据传输。
6 提升硬件水平
以本人的经历为例,同样的代码,使用A800中高端显卡的计算速度要比A5000中低端显卡的速度快1倍多。所以,提升硬件水平不失为最后的一种手段。
三、 性能监测方法
需要优化代码性能的时候我们可以使用pytorch自带的性能检测方法,把需要监测的代码片段放入如下语句中:
with profiler.profile(use_device = 'cuda', profile_memory=True, with_stack=True) as prof:
使用如下语句打印性能数据
print(prof.key_averages().table(sort_by="cuda_time_total"))
这样,在训练过程中将以GPU耗时由大到小的顺序列出调用函数名称。开发人员可以分析性能瓶颈,以实现进一步调优。