当前位置: 首页 > article >正文

timm库加载的模型可视化

在深度学习中,模型的可视化有助于了解模型的结构和层级关系。以下是几种方式来可视化使用 timm 库加载的模型:

  1. 打印模型结构
torch.nn.Module 的子类(包括 timm 的模型)可以通过 print() 查看其结构:

import timm

# 加载模型
model = timm.create_model('resnet50', pretrained=True)

# 打印模型结构
print(model)

虽然这种方式可以快速查看模型结构,但对于复杂模型,结果可能会显得混乱。

  1. 使用 torchsummary 打印摘要信息

torchsummary 可以显示模型的每一层、输出形状和参数数量:

安装 torchsummary

pip install torch-summary

使用 torchsummary

from torchsummary import summary
import timm

加载模型

model = timm.create_model('resnet50', pretrained=True)

显示模型摘要信息

summary(model, input_size=(3, 224, 224))  # 假设输入是 3x224x224 的图像

这会输出模型每一层的输入/输出形状和参数数量。

  1. 使用 torchviz 可视化模型图

torchviz 可以生成模型的计算图。

安装 torchviz

pip install torchviz

生成计算图

from torchviz import make_dot
import timm
import torch

# 加载模型
model = timm.create_model('resnet50', pretrained=True)

# 模拟输入
dummy_input = torch.randn(1, 3, 224, 224)

# 获取模型的计算图
output = model(dummy_input)
dot = make_dot(output, params=dict(model.named_parameters()))

# 保存为 PDF 或显示
dot.render("model_visualization", format="pdf")  # 保存为 PDF 文件

生成的 PDF 文件会显示模型的计算图,包括张量流。

  1. 使用 torch.fx 绘制计算图

PyTorch 的 torch.fx 工具支持将模型转为计算图,并生成更清晰的可视化。

使用 torch.fx 绘制图

import torch.fx
import timm
from torch.fx.graph_module import GraphModule

# 加载模型
model = timm.create_model('resnet50', pretrained=True)

# 构建 GraphModule
traced = torch.fx.symbolic_trace(model)

# 打印计算图
print(traced.graph)

# 可进一步转换为可视化格式(如保存为文件等)
  1. 使用 netron 进行高级可视化

netron 是一个强大的工具,支持查看模型的详细结构和参数。

步骤

1.	将模型保存为 ONNX 格式。
2.	使用 netron 加载和可视化。

代码示例:导出为 ONNX

import timm
import torch

# 加载模型
model = timm.create_model('resnet50', pretrained=True)

# 模拟输入
dummy_input = torch.randn(1, 3, 224, 224)

# 导出为 ONNX 格式
torch.onnx.export(model, dummy_input, "model.onnx", opset_version=11)

# 使用 netron 打开模型

启动 netron

pip install netron netron model.onnx

netron 将打开一个浏览器窗口,显示模型结构。

  1. 使用 hiddenlayer 可视化

hiddenlayer 可以生成直观的模型结构图。

安装 hiddenlayer

pip install hiddenlayer

可视化模型

import hiddenlayer as hl
import timm
import torch

加载模型

model = timm.create_model('resnet50', pretrained=True)

模拟输入

dummy_input = torch.randn(1, 3, 224, 224)

构建模型图

hl_graph = hl.build_graph(model, dummy_input)
hl_graph.save("model_structure.png")  # 保存为图片

总结

•	快速查看结构:直接使用 print(model)。
•	结构摘要:使用 torchsummary。
•	计算图可视化:使用 torchviz 或 torch.fx。
•	高级交互式查看:导出为 ONNX 格式并使用 netron。
•	直观的结构图:使用 hiddenlayer。

根据需求选择适合的方式,简单场景用 print,复杂场景用 netron 或 torchviz。


http://www.kler.cn/a/408744.html

相关文章:

  • IIFE - 立即执行函数
  • Python入门(14)--数据分析基础
  • 【100ask】IMX6ULL开发板用SPI驱动RC522模块
  • 使用llama.cpp进行量化和部署
  • 海洋通信船舶组网工业4G路由器应用
  • Mac 修改默认jdk版本
  • 【Python-办公自动化】实现自动化输出模板表格报告
  • MongoDB 中设置登录账号密码可以通过以下步骤实现
  • 基于SSM的婚庆管理系统+LW示例参考
  • 了解rk3588单片机
  • 大模型工程化部署:使用FastChat部署基于OpenAI API兼容大模型服务
  • 应用案例 | 西门子能源选用ASPION G-Log 2冲击记录仪,揭秘高压开关设备运输背后的安全保障
  • UG NX二次开发(C++)-UIStyler-指定平面的对象和参数获取
  • 零基础学指针(上)
  • Python爬取豆瓣电影全部分类数据并存入数据库
  • 【大数据学习 | Spark-Core】关于distinct算子
  • STM32完全学习——使用标准库完成PWM输出
  • Spring Cloud Consul实现选举机制
  • springboot 整合 rabbitMQ (延迟队列)
  • js函数声明
  • 在SQLyog中导入和导出数据库
  • 在复现SDXL-Turbo和stable-diffusion-2-1系列大模型过程中遇到的问题以及解决方案
  • 机器学习周志华学习笔记-第5章<神经网络>
  • 自动化运维-Linux通用性日志切割脚本
  • 接口性能优化宝典:解决性能瓶颈的策略与实践
  • neo4j图数据库community-5.50创建多个数据库————————————————