当前位置: 首页 > article >正文

神经网络参数量和运算量的计算- 基于deepspeed库和thop库函数

引言

最近需要对神经网络的参数量和运算量进行统计。找到一个基于deepspeed库函数计算参数量和运算量的例子。而我之前一直用thop库函数来计算。

看到有一篇勘误博文写道使用thops库得到的运算量是MACs (Multiply ACcumulate operations,乘加累积操作次数),而很多其他文章提到的还是FLOPs(Floating Point Operations,浮点运算次数)。
Pytorch: 采用thop库正确计算模型计算量FLOPs和模型参数Params 【误区更正】
因此对这两种方法进行测试,来验证thop库函数得到的运算量到底是MACs还是Flops。

1 使用deepspeed库函数计算参数量和运算量

对于deepspeed库的安装就不多介绍了,对于window系统,deepspeed的最新版本可以直接通过pip下载。不需要像以前一样安装过程一把辛酸泪。(2025.2.3)

win10上安装看一下文档:
链接: windows系统安装deepspeed说明文档

以下以resnet18为例子

import sys
import torch
from deepspeed.profiling.flops_profiler import get_model_profile
torch.backends.cudnn.deterministic = True
import torchvision.models as models

def main(argv):
    device = "cuda:0"
    net = models.resnet18()
    net.to(device).eval()
    width, height = 224, 224
    flops, macs, params = get_model_profile(net, (1,3,width,height))
    print("params: ", params)
    print("flops: ", flops)
    print("macs: ", macs)
if __name__ == "__main__":
    main(sys.argv)

结果如下:
打印了每一层的运算量和参数量:
在这里插入图片描述
最后打印的结果如下:
在这里插入图片描述

2 使用thop库函数计算运算量和参数量

import torch
from thop import profile
from thop import clever_format
import torchvision.models as models

# 假设我们有一个预训练的模型
model = models.resnet18()
model.eval()

# 使用thop分析模型的运算量和参数量
input = torch.randn(1, 3, 224, 224)  # 随机生成一个输入张量,这个尺寸应该与模型输入的尺寸相匹配
MACs, params = profile(model, inputs=(input,))

# 将结果转换为更易于阅读的格式
MACs, params = clever_format([MACs, params], '%.3f')

print(f"运算量:{MACs}, 参数量:{params}")

在这里插入图片描述

3.结论

从以上两种方法对于ResNet-18的运算量的比较可以得知,
deepspeed库统计ResNet-18的运算量FLOPs为3.64G,MACs为1.81G。
thop库统计 ResNet-18的运算量为1.824G,这个数值上更接近deepspeed库的MACs或者是FLOPs/2。

所以 thop库获得的运算量更可能是MACs而不是Flops


http://www.kler.cn/a/531529.html

相关文章:

  • DeepSeek R1 简易指南:架构、本地部署和硬件要求
  • nodejs:express + js-mdict 网页查询英汉词典,能播放声音
  • 本地部署DeepSeek-R1保姆级教程
  • 2025年1月22日(网络编程 udp)
  • 前端学习-事件委托(三十)
  • 本地搭建deepseek-r1
  • 论文paper(更新...)
  • Apache Hudi数据湖技术应用在网络打车系统中的系统架构设计、软硬件配置、软件技术栈、具体实现流程和关键代码
  • P7497 四方喝彩 Solution
  • Linux+Docer 容器化部署之 Shell 语法入门篇 【Shell 循环类型】
  • Ollama教程:轻松上手本地大语言模型部署
  • linux库函数 gettimeofday() localtime的概念和使用案例
  • kamailio源文件modules.lst的内容解释
  • 通信方式、点对点通信、集合通信
  • SpringBoot中关于knife4j 中的一些相关注解
  • 鸢尾花书《编程不难》01---基本介绍和Jupyterlab的上手
  • 无人机PX4飞控 | PX4源码添加自定义uORB消息并保存到日志
  • Codeforces Round 1002 (Div. 2)(A-D)
  • FBX SDK的使用:读取Mesh
  • centos stream 9 安装 libstdc++-static静态库
  • 【优先算法】专题——前缀和
  • 洛谷[USACO08DEC] Patting Heads S
  • 详解Linux系统的终端(Terminal)以及分类(各种tty开头的设备文件)
  • 蓝桥杯python基础算法(2-1)——排序
  • PHP Composer:高效依赖管理工具详解
  • 鲸鱼算法 matlab pso