深度学习常见指标——FLOPs(搭配代码食用)
文章目录
- 1.前言
- 2.简介
- 2.1MACs
- 2.2两者关系
- 3.计算公式
- 3.1 卷积层的FLOPs计算公式
- 3.2 全连接层的FLOPs计算公式
- 3.3 RELU层FLOPs计算
- 4.库thop说明
- 5.例子:resnet18的FLOPs
1.前言
首先明确一个概念:FLOPS和FLOPs不一样
- FLOPS是处理器性能的衡量指标,:“每秒所执行的浮点运算次数”的缩写
- FLOPs是算法复杂度的衡量指标,“浮点运算次数”的缩写,s代表的是复数
2.简介
FLOPs(Floating point operations)浮点运算次数
是衡量算法复杂度的指标(包括加减乘除等所有浮点运算)
2.1MACs
MACs(Multiply accumulate operations)
:乘法和加法运算的次数,一个乘法与加法算一个MACs
深度学习的很多情况下,尤其是卷积神经网络,通常是先进行元素乘法,然后将结果累加起来。因此,MACS可以很好地反映卷积神经网络中的计算量。
2.2两者关系
Flops=2*MACs
解释:每进行一次乘法,一般都会跟上一次加法,例如10次乘法,会伴随(10-1)次加法。
两者只相差1,因此近似看作两倍关系。
3.计算公式
3.1 卷积层的FLOPs计算公式
flops=2×(output_h × output_w)× (kernel_size)×in_channel×output_channel
flops=2×(62 × 62 )×(3×3) × 3 × 32=6642432
代码运行:
import torch
import torch.nn as nn
from thop import profile
if __name__ == '__main__':
model=nn.Sequential(nn.Conv2d(3,32,kernel_size=3,stride=1,padding=0))
# 输入数据
inputs=torch.randn(1,3,64,64)
macs,params=profile(model,inputs=(inputs,),verbose=True)
flops=2*macs
print('flops: ',flops)
print('params: ',params)
3.2 全连接层的FLOPs计算公式
代码运行
if __name__ == '__main__':
model=nn.Sequential(nn.Linear(256,10))
# 输入数据
inputs=torch.randn(1,3,256)
# 使用thop计算参数量
macs,params=profile(model,inputs=(inputs,),verbose=False)
flops=2*macs
print('flops: ',flops)
print('params: ',params)
结果如下:
15360=2×3×256×10
3.3 RELU层FLOPs计算
- 在计算 FLOPs 时,通常只考虑实际的浮点运算,如乘法和加法,因此 ReLU 的 FLOPs 通常被视为 0。
- 因为这些操作(例如判断一个数是否大于零)并不涉及浮点运算。ReLU 中的比较操作通常不计)FLOPS.
4.库thop说明
thop 是一个用于计算PyTorch模型FLOPs(每秒浮点运算次数)和参数数量的Python库。它可以帮助开发者理解和优化深度学习模型的计算复杂度,对于模型的性能分析、优化以及资源需求评估等方面非常有用。
参数说明:
- model:传入的模型
- inputs:输入(格式一般是元组)
- verbose:是否输出日志信息(True输出,False不输出)
返回数据:
- MACs
- FLOPs
- Params:参数所占空间大小
5.例子:resnet18的FLOPs
import torch
from torchvision import models
import thop
if __name__ == '__main__':
#resnet18
model=models.resnet18(weights=None)
#输入数据
inputs=torch.randn(1,3,224,224)
#计算
MACs,Params=thop.profile(model,inputs=(inputs,),verbose=True)
FLOPs=MACs*2
MACs,FLOPs,Params=thop.clever_format([MACs,FLOPs,Params],'%.3f')
print(f'MACs:{MACs}')
print(f'FLOPs:{FLOPs}')
print(f'Params:{Params}')
print(model)