当前位置: 首页 > article >正文

如何得到深度学习模型的参数量和计算复杂度

1.准备好网络模型代码

import torch
import torch.nn as nn
import torch.optim as optim

# BP_36: 输入2个节点,中间层36个节点,输出25个节点
class BP_36(nn.Module):
    def __init__(self):
        super(BP_36, self).__init__()
        self.fc1 = nn.Linear(2, 36)  # 输入2个节点,中间层36个节点
        self.fc2 = nn.Linear(36, 25)  # 输出25个节点

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # 使用ReLU激活函数
        x = self.fc2(x)
        return x

# BP_64: 输入2个节点,中间层64个节点,输出25个节点
class BP_64(nn.Module):
    def __init__(self):
        super(BP_64, self).__init__()
        self.fc1 = nn.Linear(2, 64)  # 输入2个节点,中间层64个节点
        self.fc2 = nn.Linear(64, 25)  # 输出25个节点

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # 使用ReLU激活函数
        x = self.fc2(x)
        return x

# Bi-LSTM: 输入2个节点,中间层36个节点,线性层输入72个节点,输出25个节点
class Bi_LSTM(nn.Module):
    def __init__(self):
        super(Bi_LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=2, hidden_size=36, bidirectional=True, batch_first=True)  # 双向LSTM
        self.fc1 = nn.Linear(72, 25)  # LSTM的输出72维,经过线性层后输出25个节点

    def forward(self, x):
        # x的形状应该是(batch_size, seq_len, input_size)
        x, _ = self.lstm(x)  # 输出LSTM的结果
        x = self.fc1(x)
        return x

# Bi-GRU: 输入2个节点,中间层36个节点,线性层输入72个节点,输出25个节点
class Bi_GRU(nn.Module):
    def __init__(self):
        super(Bi_GRU, self).__init__()
        self.gru = nn.GRU(input_size=2, hidden_size=36, bidirectional=True, batch_first=True)  # 双向GRU
        self.fc1 = nn.Linear(72, 25)  # GRU的输出72维,经过线性层后输出25个节点

    def forward(self, x):
        # x的形状应该是(batch_size, seq_len, input_size)
        x, _ = self.gru(x)  # 输出GRU的结果
        x = self.fc1(x)
        return x

2.运行计算参数量和复杂度的脚本

import torch
# from net import BP_36
# from net import BP_64
# from net import Bi_LSTM
from net import Bi_GRU

from ptflops import get_model_complexity_info
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# 统计Transformer模型的参数量和计算复杂度
model_transformer = Bi_GRU()
model_transformer.to(device)
flops_transformer, params_transformer = get_model_complexity_info(model_transformer, (256,2), as_strings=True, print_per_layer_stat=False)
print('模型参数量:' + params_transformer)
print('模型计算复杂度:' + flops_transformer)


http://www.kler.cn/a/466337.html

相关文章:

  • HTML——56.表单发送
  • 基于氢氧燃料电池的分布式三相电力系统Simulink建模与仿真
  • 【Rust自学】10.6. 生命周期 Pt.2:生命周期的语法与例子
  • 『SQLite』详解运算符
  • 旧服务改造及微服务架构演进
  • XIAO ESP32 S3网络摄像头——2视频获取
  • 【图像处理】OpenCv + Python 实现 Photoshop 中的色彩平衡功能
  • 机器学习经典算法——逻辑回归
  • 在K8S中,Pod请求另一个Pod偶尔出现超时或延迟,如何排查?
  • 【LeetCode】803、打砖块
  • BurpSuite2024.11
  • JLINK V9插入电脑没反应
  • 基于深度学习的视觉检测小项目(二) 环境和框架搭建
  • pytorch张量高级索引介绍
  • Sublime Text4 4189 安装激活【 2025年1月3日 亲测可用】
  • LLM 中的 Decoder Only
  • df.set_index(‘name‘).groupby(‘team‘).apply(first_3, ‘Q1‘)
  • 被催更了,2025元旦源码继续免费送
  • 一文讲清楚webpack和vite原理
  • Vue 快速入门:开启前端新征程
  • 像品茶一样品设计模式,早日突破编码新境界。
  • 2025 年软件行业展望:除了 AI,还有更多精彩
  • STM32CUBE快速入门02
  • 免费下载 | 2024网络安全产业发展核心洞察与趋势预测
  • 【VUE】使用create-vue快速创建一个vue + vite +vue-route 等其他查看的工程
  • 私享樱花盛景:滨江一品苑,尊享春日浪漫