当前位置: 首页 > article >正文

pytorch 卷积神经网络可视化 通过HiddenLayer和PyTorchViz可视化网络(已解决)

1  HiddenLayer

创建一个简单的网络

import torch
import torch.nn as nn

import hiddenlayer as h
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, 3, 1, 1),
            nn.ReLU(),
            nn.AvgPool2d(2, 2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.fc = nn.Sequential(
            nn.Linear(32 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU()
        )
        self.out = nn.Linear(64, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        output = self.out(x)
        return output

输出网络结构:

MyConvNet = ConvNet()
print(MyConvNet)

输出结果:

ConvNet(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=1568, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
  )
  (out): Linear(in_features=64, out_features=10, bias=True)
)

 通过HiddenLayer可视化网络

安装HiddenLayer

pip install hiddenlayer

绘制的基本程序如下:

import hiddenlayer as h
 vis_graph = h.build_graph(MyConvNet, torch.zeros([1 ,1, 28, 28]))   # 获取绘制图像的对象
 vis_graph.theme = h.graph.THEMES["blue"].copy()     # 指定主题颜色
 vis_graph.save("./demo1.png")   # 保存图像的路径

2 PyTorchViz

安装库:

pip install torchviz

只使用可视化函数make_dot()来获取绘图对象,基本使用和HiddenLayer差不多,不同的地方在于PyTorch绘图之前可以指定一个网络的输入值和预测值。

from torchviz import make_dot
x = torch.randn(1, 1, 28, 28).requires_grad_(True)  # 定义一个网络的输入值
y = MyConvNet(x)    # 获取网络的预测值
 ​
MyConvNetVis = make_dot(y, params=dict(list(MyConvNet.named_parameters()) + [('x', x)]))
MyConvNetVis.format = "png"
# 指定文件生成的文件夹
MyConvNetVis.directory = "data"
# 生成文件
MyConvNetVis.view()

打开与上述代码相同根目录下的data文件夹,里面会有一个.gv文件和一个.png文件,其中的.gv文件是Graphviz工具生成图片的脚本代码,.png.gv文件编译生成的图片,直接打开.png文件就行。

默认情况下,上述程序运行后会自动打开.png文件

 

欢迎点赞 收藏  加 关注

参考:

PyTorch下的可视化工具 - 知乎

 


http://www.kler.cn/a/589890.html

相关文章:

  • java学习总结(八):Spring boot
  • 2025深圳国际数字能源展全球招商启动,聚焦能源产业数字化转型
  • 【C++】*和到底如何使用?关于指针的一些理解
  • OpenCV实现图像特征提取与匹配
  • 最小二乘法的算法原理
  • 【React】useEffect、useLayoutEffect底层机制
  • 工业物联网的“边缘革命”:研华IoT Edge 设备联网与边缘计算的突破与实践
  • 蓝桥杯[每日一题] 模拟题:蚂蚁感冒(java版)
  • 项目实战系列:基于瑞萨RA6M5构建多节点OTA升级-系统设计<一>
  • ArcGIS Pro中加载在线地图的详细指南
  • 3.1 Spring Boot性能优化:从线程池调优到JVM参数配置
  • java web 安全,如何认证客户端?时间戳和noce如何抵御重放攻击?
  • 8051汇编--条件转移指令
  • 网络框架OkHttp与Retrofit原理剖析
  • Linux-c-粘住位
  • uni-app学习笔记——自定义模板
  • [Java实战]Spring Boot服务CPU 100%问题排查:从定位到解决
  • 爬虫逆向:详细讲述Android底层原理及机制
  • 职业教育五金建设改革解析
  • 0基础 | 直流稳压电源专题3