timm库加载的模型可视化
在深度学习中,模型的可视化有助于了解模型的结构和层级关系。以下是几种方式来可视化使用 timm 库加载的模型:
- 打印模型结构
torch.nn.Module 的子类(包括 timm 的模型)可以通过 print() 查看其结构:
import timm
# 加载模型
model = timm.create_model('resnet50', pretrained=True)
# 打印模型结构
print(model)
虽然这种方式可以快速查看模型结构,但对于复杂模型,结果可能会显得混乱。
- 使用 torchsummary 打印摘要信息
torchsummary 可以显示模型的每一层、输出形状和参数数量:
安装 torchsummary
pip install torch-summary
使用 torchsummary
from torchsummary import summary
import timm
加载模型
model = timm.create_model('resnet50', pretrained=True)
显示模型摘要信息
summary(model, input_size=(3, 224, 224)) # 假设输入是 3x224x224 的图像
这会输出模型每一层的输入/输出形状和参数数量。
- 使用 torchviz 可视化模型图
torchviz 可以生成模型的计算图。
安装 torchviz
pip install torchviz
生成计算图
from torchviz import make_dot
import timm
import torch
# 加载模型
model = timm.create_model('resnet50', pretrained=True)
# 模拟输入
dummy_input = torch.randn(1, 3, 224, 224)
# 获取模型的计算图
output = model(dummy_input)
dot = make_dot(output, params=dict(model.named_parameters()))
# 保存为 PDF 或显示
dot.render("model_visualization", format="pdf") # 保存为 PDF 文件
生成的 PDF 文件会显示模型的计算图,包括张量流。
- 使用 torch.fx 绘制计算图
PyTorch 的 torch.fx 工具支持将模型转为计算图,并生成更清晰的可视化。
使用 torch.fx 绘制图
import torch.fx
import timm
from torch.fx.graph_module import GraphModule
# 加载模型
model = timm.create_model('resnet50', pretrained=True)
# 构建 GraphModule
traced = torch.fx.symbolic_trace(model)
# 打印计算图
print(traced.graph)
# 可进一步转换为可视化格式(如保存为文件等)
- 使用 netron 进行高级可视化
netron 是一个强大的工具,支持查看模型的详细结构和参数。
步骤
1. 将模型保存为 ONNX 格式。
2. 使用 netron 加载和可视化。
代码示例:导出为 ONNX
import timm
import torch
# 加载模型
model = timm.create_model('resnet50', pretrained=True)
# 模拟输入
dummy_input = torch.randn(1, 3, 224, 224)
# 导出为 ONNX 格式
torch.onnx.export(model, dummy_input, "model.onnx", opset_version=11)
# 使用 netron 打开模型
启动 netron
pip install netron netron model.onnx
netron 将打开一个浏览器窗口,显示模型结构。
- 使用 hiddenlayer 可视化
hiddenlayer 可以生成直观的模型结构图。
安装 hiddenlayer
pip install hiddenlayer
可视化模型
import hiddenlayer as hl
import timm
import torch
加载模型
model = timm.create_model('resnet50', pretrained=True)
模拟输入
dummy_input = torch.randn(1, 3, 224, 224)
构建模型图
hl_graph = hl.build_graph(model, dummy_input)
hl_graph.save("model_structure.png") # 保存为图片
总结
• 快速查看结构:直接使用 print(model)。
• 结构摘要:使用 torchsummary。
• 计算图可视化:使用 torchviz 或 torch.fx。
• 高级交互式查看:导出为 ONNX 格式并使用 netron。
• 直观的结构图:使用 hiddenlayer。
根据需求选择适合的方式,简单场景用 print,复杂场景用 netron 或 torchviz。