pytorch 卷积神经网络可视化 通过HiddenLayer和PyTorchViz可视化网络(已解决)
1 HiddenLayer
创建一个简单的网络
import torch
import torch.nn as nn
import hiddenlayer as h
class ConvNet(nn.Module):
def __init__(self):
super(ConvNet, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(1, 16, 3, 1, 1),
nn.ReLU(),
nn.AvgPool2d(2, 2)
)
self.conv2 = nn.Sequential(
nn.Conv2d(16, 32, 3, 1, 1),
nn.ReLU(),
nn.MaxPool2d(2, 2)
)
self.fc = nn.Sequential(
nn.Linear(32 * 7 * 7, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU()
)
self.out = nn.Linear(64, 10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
output = self.out(x)
return output
输出网络结构:
MyConvNet = ConvNet()
print(MyConvNet)
输出结果:
ConvNet(
(conv1): Sequential(
(0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(conv2): Sequential(
(0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(fc): Sequential(
(0): Linear(in_features=1568, out_features=128, bias=True)
(1): ReLU()
(2): Linear(in_features=128, out_features=64, bias=True)
(3): ReLU()
)
(out): Linear(in_features=64, out_features=10, bias=True)
)
通过HiddenLayer可视化网络
安装HiddenLayer
pip install hiddenlayer
绘制的基本程序如下:
import hiddenlayer as h
vis_graph = h.build_graph(MyConvNet, torch.zeros([1 ,1, 28, 28])) # 获取绘制图像的对象
vis_graph.theme = h.graph.THEMES["blue"].copy() # 指定主题颜色
vis_graph.save("./demo1.png") # 保存图像的路径
2 PyTorchViz
安装库:
pip install torchviz
只使用可视化函数make_dot()
来获取绘图对象,基本使用和HiddenLayer
差不多,不同的地方在于PyTorch
绘图之前可以指定一个网络的输入值和预测值。
from torchviz import make_dot
x = torch.randn(1, 1, 28, 28).requires_grad_(True) # 定义一个网络的输入值
y = MyConvNet(x) # 获取网络的预测值
MyConvNetVis = make_dot(y, params=dict(list(MyConvNet.named_parameters()) + [('x', x)]))
MyConvNetVis.format = "png"
# 指定文件生成的文件夹
MyConvNetVis.directory = "data"
# 生成文件
MyConvNetVis.view()
打开与上述代码相同根目录下的data文件夹,里面会有一个.gv
文件和一个.png
文件,其中的.gv
文件是Graphviz工具生成图片的脚本代码,.png
是.gv
文件编译生成的图片,直接打开.png
文件就行。
默认情况下,上述程序运行后会自动打开.png文件
欢迎点赞 收藏 加 关注
参考:
PyTorch下的可视化工具 - 知乎