当前位置: 首页 > article >正文

分类 classificaton

b10049398220446a9618e92493250956.png

1)什么是分类?

在此之前,我们一直使用的都是回归任务进行学习;这里我们将进一步学习什么是分类,我们先从训练模型的角度来看看二者的区别。

82a92f15506b468a9e2e416f68dd3d04.png

 对于回归来说,它所作的是对模型输入相应的特征,然后模型给出相应的输出,需要让模型的输出和实际的标签值越接近越好;而对于分类来说,同样的是将相应的特征输入模型,模型输出相应的类型。

1.1问题一:

分类模型的输出不像回归模型一样输出是一个特定的数值,所以对于分类模型来说我们可以根据将不同的类别使用不同的数值来代替。

例如:

class1 --- 1

class2 --- 2

class3 --- 3

这样当模型认为输入的特征和class1更符合的话就会输出1,class3更符合的话就会输出3。但是又会出现新的问题,采用以上编码方式是否会导致模型认为class1和class2更相似;class1和class3更加的不同呢?因为class1和class2的距离上更近,比如网络的输出值是1.49,其实就表示很大概率是class1或者class2,这样其实就隐含表示class1或者class2更近一点。

1.2问题二:

对于有些分类任务来说,采用以上编码方式是不会具有问题的。比如使用升高和体重来预测小朋友的年级,例如:一年级 --- 1、二年级 --- 2、三年级 --- 3。这样是没问题的,因为一年级和二年级这两个类别来说是相对更近的, 一年级和三年级这两个类别来说是相对更远的。但是对于有的分类任务来说,再编码的时候就会产生这样的问题,于是在编码的时候采用one-hot vector的方式进行编码。

例如:

 eq?class1%3D%5Cleft%20%5B%20%5Cbegin%7Bmatrix%7D%201%20%5C%5C%200%20%5C%5C%200%20%5C%5C%20%5Cend%7Bmatrix%7D%20%5Cright%20%5D

 eq?class2%3D%5Cleft%20%5B%20%5Cbegin%7Bmatrix%7D%200%20%5C%5C%201%20%5C%5C%200%20%5C%5C%20%5Cend%7Bmatrix%7D%20%5Cright%20%5D

 eq?class3%3D%5Cleft%20%5B%20%5Cbegin%7Bmatrix%7D%200%20%5C%5C%200%20%5C%5C%201%20%5C%5C%20%5Cend%7Bmatrix%7D%20%5Cright%20%5D

采用这种编码方式的话,就不会出现以上这种问题了,这样的话,他们之间的距离都是一样的啦。

1.3问题三:

在回归问题的时候,我们构建的神经网络只能输出一个数值,但是对于分类问题来说,要是采用one-hot的编码方式对类别进行编码,那么对于网络的输出就不能只有一个,所以网络的结构也必须改变。

所以如下图所示,只需要多输出两个就行。

58aaed9f9f3b4fd5bf2b507ae315ddf7.png

至此对于一个分类任务的模型我们已经构架完成了,只不过是对于回归问题进行了一些小小的改进,但是其实对于分类问题来说,还有一些不太一样的问题。


 我们来最终对比一下回归任务和分类任务,分类任务最终的输出要和实际的标签纸越接近越好;对于分类来说,其最终的输出也应该与实际的类别标签纸越接近越好。

fbfc23417b5245f09ddfac263253aa7d.png

可是实际在最后一步输出的过程中,即最后网络输出了y之后,对于分类问题来说会再加上一个softmax使其输出 eq?y%5E%7B%27%7D ,然后希望的是  eq?y%5E%7B%27%7D和实际的标签值越接近越好。

2)为什要加softmax?

简单理解即使,其实对于网络的输出的三个值是可以为任何值,但是在最终的标签我们是希望在零到一之间的,所以通过softmax就可以将网络输出的值规格化到零到一之间。

softmax的工作过程:

  • 对其所以的输入的y值(网络最后的输出)取exp,也就是分别计算e^{y_i}
  • e^{y_i}进行求和
  • 用分别计算得到的e^{y_i}比上所有e^{y_i}的和就得到了每个数值最后被规格化后的数值。

例子:

输入softmax的三个数值是3、1、-3。 

  • 取exp。得到e^3=20,e^1=2.7,e^{-3}=0.05
  • 求和。20 + 2.7 + 0.05 = 22.75
  • 归一化。\frac{0.05}{22.75}\approx 0,\frac{2.7}{22.75}\approx 0.12,\frac{20}{22.75}\approx 0.88

其实在实际的分类问题当中,当分类任务是两个类别的时候,我们更常用的是使用sigmoid函数来进行最后的归一化;但是其实也可以使用softmax,他们二者在二分类的使用上无本质的区别。

2.1sigmoid函数:

sigmoid函数原型表达式如下:

                                                              sigmoid = \frac{1}{1+e{-x}}

以输入x_1,x_2为例子。

  • sigmoid:output(x_1)=\frac{1}{1+e{-x_1}}
  • softmax:output(x_1)=\frac{e^{x_1}}{e^{x_1}+e^{x_2}}=\frac{1}{1+e^{x_2-x_1}}=\frac{1}{1+e^{-(x_1-x_2)}}
  • 对于二分类来说可以进一步将softmax写成:softmax =\frac{1}{1+e^{-z_1}}

由此可得对于二分类问题来说,其二者的公式无本质区别,即理论上来说,二者是没有任何区别的。

sigmoid和softmax函数的本质区别: 

sigmoid函数用于多标签分类问题,选取多个标签作为正确答案,它是将任意值归一化为[0-1]之间,并不是不同概率之间的相互关联

Softmax函数用于多分类问题,即从多个分类中选取一个正确答案。Softmax综合了所有输出值的归一化,因此得到的是不同概率之间的相互关联。 

转载来源:深度学习随笔——Softmax函数与Sigmoid函数的区别与联系 - 知乎

Sigmoid函数针对两点分布提出。神经网络的输出经过它的转换,可以将数值压缩到(0,1)之间,得到的结果可以理解成分类成目标类别的概率P,而不分类到该类别的概率是(1 - P),这也是典型的两点分布的形式。

Softmax函数本身针对多项分布提出,当类别数是2时,它退化为二项分布。而它和Sigmoid函数真正的区别就在——二项分布包含两个分类类别(姑且分别称为A和B),而两点分布其实是针对一个类别的概率分布,其对应的那个类别的分布直接由1-P得出。

简单点理解就是,Sigmoid函数,我们可以当作成它是对一个类别的“建模”,将该类别建模完成,另一个相对的类别就直接通过1减去得到。而softmax函数,是对两个类别建模,同样的,得到两个类别的概率之和是1。

3)分类问题的损失函数

分类问题的损失函数同样的根据距离来计算,可以和之前的回归问题一样使用MSE误差来计算损失函数。但是更常用是使用下图中的 Cross- entropy来计算误差

为什么选择Cross- entropy而不是Mean Square error

其实我们实际在使用pytorch进行构建网络实现分类的任务的时候,我们会发现找不到softmax,这是因为,我们再构建网络后,使用Cross-entropy来计算误差的时候,其会自动再网络的最后一层加上softmax,在 pytorch中Cross-entropy和softmax被内置成为了一个整体。

例子:

在某个训练过程中网络输出的三个数值分别是y_1,y_2,y_3,然后再经softmax处理得到了最后的输出结果,我们分别使用Cross- entropy和Mean Square error进行求损失,然后根据损失计算下一步该往哪里走。 

如上图所示得到了Cross- entropy和Mean Square error 的损失图,在图的右下角都是损失最小的地方,即y_1的值变大, y_2的值变小就可以使得损失值变小;在图的左上角都是损失最大的地方,即y_1的值变小, y_2的值变大会使得损失值变大。 

所以当使用Mean Square error 的时候,很有可能会到达梯度不变的点,导致训练不下去(一般的训练过程训练不下去),但是对于 Cross- entropy 来说,确是有梯度了,可以进行下去。 

所以可以得出对于损失函数的设计都会影响最后的一个训练优化过程。

4)pytorch实现分类代码

完整代码:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# 定义超参数
batch_size = 64
learning_rate = 0.001
num_epochs = 10

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 下载并加载 CIFAR-10 数据集
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 定义 CNN 模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 实例化模型、定义损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 训练模型
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:    # 每100个mini-batch打印一次
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}], Loss: {running_loss / 100:.4f}')
            running_loss = 0.0

print('训练完成')

# 测试模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'测试准确率: {100 * correct / total:.2f}%')

接下来将一步一步剖析代码,我们先看网络架构。

定义SimpleCNN类,让其继承prtorch中的nn.Module,并且

# 定义 CNN 模型
class SimpleCNN(nn.Module): #SimpleCNN 继承自 nn.Module
    def __init__(self): #构造函数(初始化方法)
        super(SimpleCNN, self).__init__()  # 这一行代码调用了父类(nn.Module)的构造函数,确保在实例化SimpleCNN类时,也会执行父类的初始化操作。       

第一层网络:

self.conv1 = nn.Conv2d(3, 32, 3, padding=1)

CIFAR-10图像都是3通道(RGB)的,尺寸为32x32像素。因此每个图像的数据形状是3x32x32,所以对于32个卷积核

  • 输入通道数:3,对应于 RGB 图像的三个通道。
  • 输出通道数:32,卷积核的数量,生成32个特征图。
  • 卷积核大小:3x3。
  • Padding:1,保持输出和输入的宽高相同。

CIFAR-10图像为3x32x32,输入通道是3,其实就对应着3个特征图,即RGB三个特征图。经过32个卷积核采样后,每个卷积核都能得到一个特征图,也就是32个特征图,但是在不加padding的时候得到的特征图都是4x4的,加了padding之后,就能保证最后卷积后得到额特征是还是6x6的。

第二层网络:

self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
  • 输入通道数:32,来自前一层的输出。
  • 输出通道数:64。
  • 卷积核大小:3x3。
  • Padding:1。

 第三层网络——池化层

  • 池化大小:2x2。
  • 步幅:2,减少特征图尺寸,通常用于降采样。

self.pool = nn.MaxPool2d(2, 2)

全连接层

self.fc1 = nn.Linear(64 * 8 * 8, 512)

  • 输入大小:64 * 8 * 8,来自池化层展平后的特征数。
  • 输出大小:512,隐藏层的神经元个数。

self.fc2 = nn.Linear(512, 10)

  • 输入大小:512,来自前一层的输出。
  • 输出大小:10,对应于 CIFAR-10 数据集的10个类别。

self.relu = nn.ReLU()

ReLU:一种常用的激活函数,引入非线性,计算简单,能有效缓解梯度消失问题。

前向传播:

def forward(self, x):
    x = self.pool(self.relu(self.conv1(x)))
    x = self.pool(self.relu(self.conv2(x)))
    x = x.view(-1, 64 * 8 * 8)
    x = self.relu(self.fc1(x))
    x = self.fc2(x)
    return x

  • 卷积 + ReLU + 池化
    • 首先对输入应用卷积层和 ReLU 激活函数,然后进行最大池化。
  • 展平
    • 使用 view 方法将特征图展平成一维,便于连接到全连接层。
  • 全连接层 + ReLU
    • 输入到第一个全连接层并应用 ReLU 激活。
  • 输出层
    • 最后通过第二个全连接层得到输出。

这个简单的 CNN 通过卷积、激活、池化和全连接层的组合来提取图像特征并进行分类。


http://www.kler.cn/a/381556.html

相关文章:

  • C++浅拷贝与深拷贝
  • 从APP小游戏到Web漏洞的发现
  • 与IP网络规划相关的知识点
  • 408——计算机网络(持续更新)
  • 【双指针】【数之和】 LeetCode 633.平方数之和
  • 初始JavaEE篇——多线程(5):生产者-消费者模型、阻塞队列
  • 字典学习python
  • vue props无法被watch
  • 使用Spring Validation实现数据校验详解
  • AWTK-HarmonyOS NEXT 发布
  • 华为HarmonyOS借助AR引擎帮助应用实现虚拟与现实交互的能力4-检测环境中的平面
  • QML----复制指定下标的ListModel数据
  • 【基于轻量型架构的WEB开发】课程 12.4 页面跳转 Java EE企业级应用开发教程 Spring+SpringMVC+MyBatis
  • Python Matplotlib 子图绘制
  • 省级-能源结构数据(电力消费水平)(2000-2022年)
  • 【go从零单排】go三种结构体:for循环、if-else、switch
  • 【大数据学习 | HBASE】habse的表结构
  • vue前端面试题及答案2024
  • 飞书API-获取tenant_access_token
  • Melty 主体流程图
  • ctfshow文件包含web78~81
  • 八、1.STM32之DMA实验--DMA数据转运
  • 从传统服务器到虚拟化:虚拟机 VM 如何改变计算游戏规则?
  • 【spring】Cookie和Session的设置与获取(@CookieValue()和@SessionAttribute())
  • 企业HR如何选对一款智能招聘软件?
  • 加锁失效,非锁之过,加之错也|京东零售供应链库存研发实践