神经网络搭建实战与Sequential的使用
一、需要处理的图像
二、对上述图片用代码表示:
import torch from torch import nn from torch.nn import Conv2d, MaxPool2d, Flatten, Linear class SUN(nn.Module): def __init__(self): super(SUN, self).__init__() self.conv1 = Conv2d(3, 32, 5, padding=2) self.maxpool1 = MaxPool2d(2) self.conv2 = Conv2d(32, 32, 2,padding=2) self.maxpool2 = MaxPool2d(2) self.conv3 = Conv2d(32, 64, 5, padding=2) self.maxpool3 = MaxPool2d(2) self.flatten = Flatten() self.linear1 = Linear(1024, 64) self.linear2 = Linear(64, 10) def forward(self,x): x = self.conv1(x) x = self.maxpool1(x) x = self.conv2(x) x = self.maxpool2(x) x = self.conv3(x) x = self.maxpool3(x) x = self.flatten(x) x = self.linear1(x) x = self.linear2(x) return x sun = SUN() print(sun) # 写完网络检查网络的正确性,因为即使改变其中的一些参数,该网络还是能够正常的运行,所以需要检验 # 创建一个假象的输入 input = torch.ones((64,3,32,32)) output = sun(input) print(output.shape)
实现的结果:
三、用 Sequential简化
但是,在class的使用中,频繁的写self.....是非常复杂,不简洁的,于是有了Sequential:
代码如下简洁:
import torch from torch import nn from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential class SUN(nn.Module): def __init__(self): super(SUN, self).__init__() self.model1 = Sequential( Conv2d(3, 32, 5, padding=2), MaxPool2d(2), Conv2d(32, 32, 2, padding=2), MaxPool2d(2), Conv2d(32, 64, 5, padding=2), MaxPool2d(2), Flatten(), Linear(1024, 64), Linear(64, 10) ) def forward(self, x): x =self.model1(x) return x sun = SUN() print(sun) # 写完网络检查网络的正确性,因为即使改变其中的一些参数,该网络还是能够正常的运行,所以需要检验 # 创建一个假象的输入 input = torch.ones((64,3,32,32)) output = sun(input) print(output.shape)
实现的结果是一样的,但是,较为的简洁。
四、使用tensorboard可视化
# 使用tensorboard来可视化: writer = SummaryWriter("logs_seq") writer.add_graph(sun, input) writer.close()
注意,此处使用的是add_graph()。
tensorboard中的add_graph
方法用于可视化ptorch模型的计算图。TensorBoard是一个强大的可视化工具,它允许用户以交互式的方式查看和理解模型的训练过程和结构。在PyTorch中,add_graph
方法是SummaryWriter类的一个功能,它可以将PyTorch模型的计算图以图形化的形式展示出来。通过这种方法,用户可以直观地看到模型中各个操作之间的依赖关系,以及数据在模型中的流动情况。这对于理解模型的内部工作机制、调试模型以及优化模型设计都非常有帮助。
tensorboard显示不出来的问题:使用了下述语句查看:
tensorboard --logdir =learningplan1/logs_seq
结果:
对其改正:加入端口6007
tensorboard --logdir=learningplan1/logs_seq --port=6007
最终正确查看,注:双击可以打开网络:
输入,经过搭建的SUN网络到达输出。
通过双击网络模块,可查询相关的参数等:
网络搭建成功。