当前位置: 首页 > article >正文

神经网络——CIFAR10小实战

1.引子

在这里插入图片描述
在这里插入图片描述

Sequential的使用:将网络结构放入其中即可,可以简化代码。
找了一个对CIFAR10进行分类的模型。

2.代码实战

from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = Conv2d(3, 32, 5, padding=2)
        self.maxpool1 = MaxPool2d(2)
        self.conv2 = Conv2d(32, 32, 5, padding=2)
        self.maxpool2 = MaxPool2d(2)
        self.conv3 = Conv2d(32, 64, 5, padding=2)
        self.maxpool3 = MaxPool2d(2)
        self.flatten = Flatten()
        self.linear1 = Linear(1024, 64)
        self.linear2 = Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x

tudui=Tudui()
print(tudui)

在这里插入图片描述

nn.Flatten()和torch.flatten()有相同的效果。

3.Sequential

import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1=Sequential(
            Conv2d(3,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,64,5,padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024,64),
            Linear(64,10)
        )

    def forward(self, x):
        x=self.model1(x)
        return x

tudui=Tudui()
print(tudui)
## 创建一个指定形状的 ones 张量
input=torch.ones((64,3,32,32))
output=tudui(input)
print(output.shape)

使用Sequential可以很大程度地简化代码。

4.利用TensorBoard进行数据可视化

使用SummaryWriter的add_graph()方法进行数据可视化。

writer=SummaryWriter("logs_sqe")
writer.add_graph(tudui,input)
writer.close()

在这里插入图片描述

基本的网络搭建到此结束。


http://www.kler.cn/a/286323.html

相关文章:

  • 【设计模式-行为型】备忘录模式
  • “星门计划对AI未来的意义——以及谁将掌控它”
  • springboot使用rabbitmq
  • 快速提升网站收录:避免常见SEO误区
  • ASP.NET代码审计 SQL注入篇(简单记录)
  • C++ 静态变量static的使用方法
  • 如何构建大型超市数据处理系统?Java SpringBoot搭配MySQL,实现高效数据管理!
  • Axure RP10安装教程(Pro版)
  • 考试评分系统设计与实现/基于django的在线考试系统
  • 发布npm包到GitLab教程
  • 人工智能和机器学习5 (复旦大学计算机科学与技术实践工作站)语言模型相关的技术和应用、通过OpenAI库,调用千问大模型,并进行反复询问等功能加强
  • 【网络安全】服务基础第一阶段——第四节:Windows系统管理基础---- NTFS安全权限与SMB文件共享服务器
  • Android游戏开发常见知识合集(Unity3D)
  • 距离向视数(Range Looks)方位向视数(Azimuth Looks)
  • MLM之Qwen:Qwen2-VL的简介、安装和使用方法、案例应用之详细攻略
  • 比较一下React与Vue
  • 《机器学习》—— K-means 聚类算法
  • 【微处理器系统原理与应用设计】微处理器的基本架构之组成原理和系统结构
  • 解决Qt报“undefined reference to vtable for“错误
  • 科技改变搜索习惯:Anytxt Searcher,重新定义你的信息获取方式!
  • 【王树森】Transformer模型(2/2): 从Attention层到Transformer网络(个人向笔记)
  • Java智慧社区全能平台集成跑腿家政及本地生活服务商城系统小程序源码
  • MySQL事务处理详解:实现数据一致性与隔离性的艺术
  • 【分层强化学习】Option Critic 的 CartPole-v1 的简单实例
  • MATLAB 地面点构建三角网(83)
  • 事务代码中加synchronized锁引发的bug