基于Pytorch的可视化工具

深度学习网络通常具有很深的层次结构,而且层与层之间通常会有并联、串联等连接方式。当使用PyTorch建立一个深度学习网络并输出文本向读者展示网络的连接方式是非常低效的,所以需要有效的工具将建立的深度学习网络结构有层次化的展示,这就需要使用相关的深度学习网络结构可视化库。

3.1 准备网络

import torch
import torch.nn as nn
import torchvision
import torchvision.utils as vutils
from torch.optim import SGD
import torch.utils.data as Data
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

"""
下面导入手写字体数据,并将数据处理为数据加载器
"""
train_data=torchvision.datasets.MNIST(
    root="./Dataset",
    train=True,
    #将数据转化为torch使用的张量,取值范围为[0-1]
    transform=torchvision.transforms.ToTensor(),
    download=False#因为数据已经下载过所以这里不在下载
    )
#定义一个数据加载器
train_loader=Data.DataLoader(
    dataset=train_data,#使用的数据集
    batch_size=128,#批处理样本大小
    shuffle=True,
    num_workers=0
)
#准备需要的测试集
test_data=torchvision.datasets.MNIST(
    root="./Dataset",
    train=False,
    download=False
)
#为数据添加一个通道维度,并且取值范围缩放到[0-1]之间
test_data_x=test_data.data.type(torch.FloatTensor)/255.0
test_data_x=torch.unsqueeze(test_data_x,dim=1)
test_data_y=test_data.targets#测试集的标签
print(test_data_x.shape)
print(test_data_y.shape)
"""
(1)针对训练集使用torchvision.datasets.MNIST()函数导入数据,
并将其像素值转化到0~1之间,然后使用Data.DataLoader()函数定义一个数据加载器,
每个batch包含128张图像。
(2)针对测试数据,同样使用torchvision.datasets.MNIST()函数导入数据,
但不将数据处理为数据加载器,而是将整个测试集作为一个batch,方便计算模型在测试集上的预测精度。
"""
#搭建一个卷积神经网络
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet,self).__init__()
        #定义第一个卷积层
        self.conv1=nn.Sequential(
            nn.Conv2d(
                in_channels=1,#输入的feature_map
                out_channels=16,#输出的feature_map
                kernel_size=3,#卷积核尺寸
                stride=1,#卷积核步长
                padding=1,#进行填充
            ),
            nn.ReLU(),#激活函数
            nn.AvgPool2d(
                kernel_size=2,#平均值池化层,使用2*2
                stride=2,#池化步长为2
            ),
        )
        #定义第二个卷积层
        self.conv2=nn.Sequential(
            nn.Conv2d(16,32,3,1,1),
            nn.ReLU(),#激活函数
            nn.MaxPool2d(2,2)#最大值池化
        )
        #定义全连接层
        self.fc=nn.Sequential(
            nn.Linear(
                in_features=32*7*7,#输入特征
                out_features=128#输出特征
            ),
            nn.ReLU(),#激活层
            nn.Linear(128,64),#
            nn.ReLU()#激活层
        )
        self.out=nn.Linear(64,10)#最后的分类层
    def forward(self,x):
        x=self.conv1(x),
        x=self.conv2(x),
        x=x.view(x.size(0),-1),#展平多维的卷积图层
        x=self.fc(x),
        output=self.out(x),
        return output
myConvNet=ConvNet()

3.2 网络结构的可视化---PytorchViz

从定义网络和网络的输出可以看出,在myConvNet网络结构中,共包含两个使用nn.Sequential()函数连接的卷积层,即conv1和conv2,每个层都包含有卷积层、激活函数层和池化层。在fc层中,包含两个全连接层和激活函数层,out层则由一个全连接层构成。通过文本输出myConvNet网络的网络结构得到上面的输出结果,但这并不容易让读者理解在网络中层与层之间的连接方式,所以需要将PyTorch搭建的深度学习网络进行可视化,通过图像来帮助读者理解网络层与层之间的连接方式。PyTorchViz库是一个可视化网络结构的库,学习和了解这些库对网络结构进行可视化,可以帮助我们查看、理解所搭建深度网络的结构。

from torchviz import make_dot
x=torch.randn(1,1,28,28).requires_grad_(True)
y=myConvNet(x)
#make_dot用来得到网络的可视化图像
myConvNets=make_dot(y,params=dict(list(myConvNet.named_parameters())+[("x",x)]))
myConvNets.format='png'#指定可视化图像的格式
myConvNets.directory="model_graph/"#指定图像保存的文件夹
myConvNets.view()

得到的图像如下图所示:

 3.3 训练过程可视化---TensorboardX

网络结构可视化,主要是帮助使用者理解所搭建的网络或检查搭建网络时存在的错误。而网络训练过程的可视化,通常用于监督网络的训练过程或呈现网络的训练效果。tensorboardX是帮助pytorch使用tensornoard工具来可视化的库,在tensorboardX库中,提供了多种向tensorboard中添加事件的函数。

函数功能用法
SummaryWriter()创建编写器,保存日志writer=SummaryWriter()
writer.add_scalar()添加标量writer.add_scalar(''myscalar",value,iteration)
writer.add_image()添加图像writer.add_image("imresult",x,iteration)
writer.add_histogram()添加直方图writer.add_histogram('hist',array,iteration)
writer.add_graph()添加网络结构writer.add_graph(model,input_to_model=None)
writer.add_audio()添加音频add_audio(tag,audio,iteration,sample_rate)
writer.add_text()添加文本writer.add_text(tag,text_string,global_step=None)

下面针对建立好的MNIST手写字体识别网络,使用rensorboardX库对网络在训练过程中的损失函数的变化情况、精度变化情况、权重分布等内容进行可视化,程序如下:

from tensorboardX import SummaryWriter
SumWriter=SummaryWriter(log_dir='./log')#日志保存路径
#定义优化器
optimizer=torch.optim.Adam(myConvNet.parameters(),lr=0.003)
#定义损失函数
loss_func=nn.CrossEntropyLoss()
train_loss=0
print_step=100#每经过100次迭代,输出损失
#对模型进行迭代训练,对所有的数据训练epoch轮
for epoch in range(5):
    #对训练数据的加载器进行迭代计算
    for step,(b_x,b_y) in enumerate(train_loader):
        #计算每个batch上的损失
        output=myConvNet(b_x)
        loss=loss_func(output,b_y)#交叉熵损失函数
        optimizer.zero_grad()#每个迭代步的梯度初始化为0
        loss.backward()#损失后向传播
        optimizer.step()#使用梯度进行优化
        train_loss=train_loss+loss#计算累加损失
        niter=epoch*len(train_loader)+step+1#计算迭代次数
        #计算每经过print_step次迭代后的输出
        if niter % print_step==0:
            #为日志添加训练集损失函数
            SumWriter.add_scalar("train loss",train_loss.item() / niter,global_step=niter)
            #计算在测试集上的精度
            output=myConvNet(test_data_x)
            _,pre_lab=torch.max(output,1)
            acc=accuracy_score(test_data_y,pre_lab)
            #为日志添加测试集上的预测精度
            SumWriter.add_scalar("test acc",acc.item(),niter)
            #为日志中添加训练数据的可视化图像,使用当前batch的图像
            #将一个batch的数据进行预处理
            b_x_im=vutils.make_grid(b_x,nrow=12)
            SumWriter.add_image('train image sample',b_x_im,niter)
            #使用直方图可视化网络中参数的分布情况
            for name,param in myConvNet.named_parameters():
                SumWriter.add_histogram(name,param.data.numpy(),niter)

在网络训练完毕之后,会得到网络的训练过程文件,该文件的可视化结果可以通过tensorboard可视化工具进行查看。

日志文件保存路径

 在conda环境下执行>tensorboard --logdir="文件路径"

 成功后,tensorboard会返回一个本地网址链接,浏览器打开该链接即可查看到可视化界面

3.5 Visdom可视化

Visdom库中包含多种用于可视化图像的接口

可视化函数功能描述
vis.image可视化一张图像
vis.image可视化一个batch的图像,或者一个图像列表
vis.text可视化文本
vis.audio用于播放音频
vis.matplot可视化Matplotlib的图像
vis.scatter2D或者3D的散点图
vis.line线图
vis.stem茎叶图
vis.heatmap热力图
vis.bar条形图
vis.histogram直方图
vis.boxplot盒形图
vis.surf曲面图
vis.contour等高线图
vis.quiver箭头图
vis.video播放音频
vis.mesh网格图

下面使用具体的数据集进行可视化图像:

import numpy as np
import torch
from visdom import Visdom
from sklearn.datasets import load_iris
iris_x,iris_y=load_iris(return_X_y=True)
print(iris_x.shape)
print(iris_y.shape)
"""
上面的程序导入了鸢尾花数据集,包含3类数据,150个样本,每个样本包含4个特征。
"""
vis=Visdom()
#2D散点图
vis.scatter(iris_x[:,0:2],Y=iris_y+1,win="windows1",env="main")
#3D散点图
vis.scatter(iris_x[:,0:3],Y=iris_y+1,win="3D 散点图",env="main",opts=dict(markersize=4,xlabel="特征一",ylabel="特征二"))
"""
程序中使用vis = Visdom()初始化一个绘图对象,通过vis.scatter()为对象添加散点图。
在该函数中,如果输入的X为二维则可得到2D散点图,如果输入的X为三维则可得到3D散点图,
其中参数Y用于指定数据的分组情况,参数win指定图像的窗口名称,参数env则指定图像所在的环境。
可以发现两幅图像都在主环境main中。图像的其他设置可使用opts参数通过字典的形式设置。
在上述初始化的可视化图像环境main中,继续添加其他窗口,以绘制不同类型的图像,如添加折线图.
"""
#添加折线图
x=torch.linspace(-6,6,100).view((-1,1))
sigmoid=torch.nn.Sigmoid()
sigmoidy=sigmoid(x)
tanh=torch.nn.Tanh()
tanhy=tanh(x)
relu=torch.nn.ReLU()
reluy=relu(x)
#连接3个张量
ploty=torch.cat((sigmoidy,tanhy,reluy),dim=1)
plotx=torch.cat((x,x,x),dim=1)
vis.line(Y=ploty,X=plotx,win="line plot",env="main",opts=dict(dash=np.array(["solid","dash","dashdot"]),legend=["Sigmoid","Tanh","ReLU"]))
"""
上面的程序中,可视化出了sigmoid、Tanh和ReLU三种激活函数的图像。
在可视化折线时,使用vis.line()函数进行绘图,图像在环境main中,
通过win参数指定窗口名称为line plot,然后通过opts参数为不同的线设置了不同的线型。
"""
#添加茎叶图
x=torch.linspace(-6,6,100).view((-1,1))
y1=torch.sin(x)
y2=torch.cos(x)
#连接2个变量
plotx=torch.cat((y1,y2),dim=1)
ploty=torch.cat((x,x),dim=1)
vis.stem(X=plotx,Y=ploty,win="stem plot",env="main",
         #设置图例
        opts=dict(legend=["sin","cos"],title="茎叶图") )
"""
上面的程序中,可视化出了正弦和余弦函数的茎叶图。
在可视化时通过vis.stem()函数绘图,图像在环境main中,
通过win参数指定窗口名称为stem plot,然后通过opts参数为图像添加图例和标题。
"""
#添加热力图
iris_corr=torch.from_numpy(np.corrcoef(iris_x,rowvar=False))
vis.heatmap(iris_corr,win="heatmap",env="main",
            #设置每个特征的名称
            opts=dict(rownames=["x1","x2","x3","x4"],columnnames=["x1","x2","x3","x4"],title="热力图"))
"""
程序中可视化出了鸢尾花数据集中4个特征的相关系数热力图。
在可视化时通过vis.heatmap()函数进行绘图,图像在环境main中,
通过win参数指定窗口名称为heatmap,然后通过opts参数为图像添加X轴的变量名称、Y轴变量名称和标题。
"""

注意:在使用visdom可视化图像之前,应该在命令行激活visdom服务,否则程序报错

python -m visdom.server

命令执行成功后会返回一个本地链接: http://localhost:8097

然后在pycharm中执行代码,再用浏览器打开上述链接即可得到下面的可视化图像: 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.kler.cn/a/7447.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

chatGPT的未来应用有哪些-ChatGPT对未来工作的影响

ChatGPT对未来的影响 ChatGPT 是一种先进的自然语言处理技术,能够处理和理解大量的自然语言数据和信息,具有广泛的应用价值。以下是 ChatGPT 可能对未来的影响: 改变人与计算机的交互方式。ChatGPT 的普及应用,将使得人们可以通过…

网络安全书籍推荐+网络安全面试题合集

一、计算机基础 《深入理解计算机系统》《鸟哥的Linux私房菜》《TCP/IP详解(卷1:协议)》《HTTP权威指南》《Wireshark数据包分析实战》《Wireshark网络分析的艺术》《Wireshark网络分析就这么简单》 二、网络渗透 《白帽子讲Web安全》《Web…

2023年五一数学建模竞赛来袭

1.竞赛介绍 五一数学建模竞赛由江苏省工业与应用数学学会,中国矿业大学,徐州市工业与应用数学学会联合举办,历史悠久,距离第一届比赛已经有20年历史,可以说是仅次于高教社杯国赛的一项数学建模竞赛。比较适合数模竞赛…

浅谈JVM(五):虚拟机栈帧结构

上一篇: 浅谈JVM(一):Class文件解析 浅谈JVM(二):类加载机制 浅谈JVM(三):类加载器和双亲委派 浅谈JVM(四):运行时数据区 5.虚拟机栈帧结构 ​ 方法是程序执行的最小单元,每个方法被执行时都会创建一个栈帧…

关于CH32F203程序下载方式说明

关于CH32F203程序下载方式说明🎉好久没有写有关wch单片机的相关内容了,具体焊接完2块CH32F203,发现烧写程序遇到了各种囧事。 📓CH32F203程序下载方式 🔨通过串口下载。接口为PA9和PA10不知道是不是各厂商之间默契的规…

Linux VIM编辑器常用指令

普通模式的基本指令 按键作用yy 复制一行 通常会与p一起使用p将复制的内容写出 数字yy 从当前行往下数数字行进行复制y^复制当前行的起始位到光标的前一位y$复制光标当前位置到行末尾yw复制光标所在的位置之后(包括光标)的(不完整&#xff0…

ffmpeg关于视频前几秒黑屏的问题解决

关于音频播放器视频前两秒黑屏的解决,及QtAV和ffmpeg的环境搭建(软件包可以找李青璠提供,也可以自己下)首先我们可以参考下面两个博客进行ffmpeg的搭建,第一个博客的问题可以在第二个博客里寻求方法解决。其中第一个博…

多线程的锁策略

文章目录前言一.乐观锁与悲观锁二.轻量级锁和重量锁三.自旋锁和挂起等待锁四.互斥锁和读写锁五.不可重入锁和可重入锁六.公平锁和非公平锁前言 其实这里指的锁策略,不只只是线程才存在的。也不只是针对Java的,我现在就即将介绍常见的锁策略。 一.乐观锁…

Python 自动化指南(繁琐工作自动化)第二版:八、输入验证

原文:https://automatetheboringstuff.com/2e/chapter8/ 输入验证代码检查用户输入的值,比如来自input()函数的文本,格式是否正确。例如,如果您希望用户输入他们的年龄,您的代码不应该接受无意义的答案,如负…

中间表示- 三地址码

使用三地址码的编译器结构 三地址码的基本思想 (1)给每个中间变量和计算结果命名,没有复合表达式 (2)只有最基本的控制流,没有各种控制结构(if、do、while、for等等),只…

2 新建工程步骤

2 新建工程步骤 0.建立工程文件夹 选择一个程序储存文件,新建一个2-1 STM32工程模板文件夹,在2-1 STM32工程模板文件夹新建一个Start,User,Library文件夹 1.Keil中新建工程,选择型号 打开keil5,project->new pr…

045:cesium加载OpenStreetMap地图

第045个 点击查看专栏目录 本示例的目的是介绍如何在vue+cesium中加载加载OpenStreetMap地图。 直接复制下面的 vue+cesium源代码,操作2分钟即可运行实现效果. 注意OpenStreetMap国内加载有问题,需要曲线救图。 文章目录 示例效果配置方式示例源代码(共79行)相关API参考:…

询问ChatGPT的高质量答案艺术——提示工程指南(更新中……)

目录前言一、提示工程简介二、提示技巧2-1、生成法律文件2-2、添加提示技巧三、角色扮演3-1、智能手机产品描述3-2、添加角色扮演四、标准提示4-1、写一篇有关于新智能手机的评论4-2、添加标准提示、角色提示、种子词提示等等五、示例很少、或者没有示例5-1、生成一个手机配置六…

什么是服务架构?微服务架构的优势又是什么?

文章目录1.1 单体架构1.2 微服务架构1.3 单体架构和微服务架构的区分1.4 两种服务架构的优劣点1.4.1 单体架构1.4.2 微服务架构1.5 总结1.1 单体架构 单体架构(Monolithic Architecture)是一种传统的应用程序架构模式,它指的是将一个应用程序…

聚会Party

前言 加油 原文 聚会常用会话 ❶ He spun his partner quickly. 他令他的舞伴快速旋转起来。 ❷ She danced without music. 她跳了没有伴乐的舞蹈。 ❸ The attendants of the ball are very polite. 舞会的服务员非常有礼貌。 ❶ Happy birthday to you! 祝你生日快乐!…

剪枝与重参第四课:NVIDIA的2:4剪枝方案

目录NVIDIA的2:4 pattern稀疏方案前言1.稀疏性的研究现状2.图解nvidia2-4稀疏方案3.训练策略4.手写复现4.1 大体框架4.2 ASP类的实现4.3 mask的实现4.4 模型初始化4.5 Layer嵌入稀疏特性4.6 优化器初始化4.7 拓展-dynamic function assignment4.8 完整示例代码总结NVIDIA的2:4 …

做了个springboot接口参数解密的工具,我给它命名为万能钥匙(已上传maven中央仓库,附详细使用说明)

前言:之前工作中做过两个功能,就是之前写的这两篇博客,最近几天有个想法,给它做成一个springboot的start启动器,直接引入依赖,写好配置就能用了 springboot使用自定义注解实现接口参数解密,普通…

4.5--计算机网络之基础篇--1.模型分层--(复习+深入)---好好沉淀,加油呀

1.TCP/IP模型的分层 1.1.为什么要有 TCP/IP 网络模型? 对于同一台设备上的进程间通信,有很多种方式,比如有管道、消息队列、共享内存、信号等方式; 而对于不同设备上的进程间通信,就需要网络通信,而设备是…

Elasticsearch:索引状态是红色还是黄色?为什么?

在我之前文章 “Elasticsearch:如何调试集群状态 - 定位错误信息” 中,我有详细介绍如何调试集群状态。在今天的文章中,我将详细介绍如何故障排除和修复索引状态。 Elasticsearch 是一个伟大而强大的系统,特别是创建一个可扩展性极…

51单片机-LED篇

目录准备工作点亮一个LED灯写程序烧录LED闪烁延时代码Delay500ms烧录LED流水灯代码对LED流水灯代码进行优化,增加复用性延时代码代码准备工作 使用到的单片机是普中51单片机 使用到的软件是Keil uVision5和stc-isp 点亮一个LED灯 写程序 首先通过Keil uVision5…
最新文章