当前位置: 首页 > article >正文

机器学习小补充(加深理解)

1. 分类交叉熵损失(Categorical Crossentropy)

定义:当标签以独热编码形式表示时使用。

原理:在多分类问题中,分类交叉熵损失用于计算模型预测的概率分布与实际分布之间的差异。模型输出的预测概率通常是一个向量,其元素表示每个类别的预测概率。

公式

假设有 N N N 个样本,每个样本的类别数为 C C C。模型输出的概率分布为 p p p,实际标签为独热编码向量 y y y。分类交叉熵损失的公式如下:

Loss = − 1 N ∑ i = 1 N ∑ c = 1 C y i c log ⁡ ( p i c ) \text{Loss} = -\frac{1}{N} \sum_{i=1}^{N} \sum_{c=1}^{C} y_{ic} \log(p_{ic}) Loss=N1i=1Nc=1Cyiclog(pic)

其中:

  • y i c y_{ic} yic 是样本 i i i 的实际标签(独热编码),如果样本属于类别 c c c,则 y i c = 1 y_{ic}=1 yic=1;否则 y i c = 0 y_{ic}=0 yic=0
  • p i c p_{ic} pic 是模型预测样本 i i i 属于类别 c c c 的概率。

2. 稀疏分类交叉熵损失(Sparse Categorical Crossentropy)

定义:当标签以整数形式表示时使用。

原理:稀疏分类交叉熵损失与分类交叉熵损失的概念相似,不同之处在于它的标签是以整数形式表示的,而不是独热编码。这种表示形式可以更方便地用于多分类问题。

公式

假设有 N N N 个样本,模型输出的概率分布为 p p p,实际标签用整数 y y y 表示。稀疏分类交叉熵损失的公式如下:

$
\text{Loss} = -\frac{1}{N} \sum_{i=1}^{N} \log(p_{i, y_i})
$

其中:

  • y i y_i yi 是样本 i i i 的实际标签(整数形式),表示样本 i i i 的类别。
  • p i , y i p_{i, y_i} pi,yi 是模型预测样本 i i i 属于真实类别的概率。

1. 准确率(Accuracy)

Accuracy = 正确分类的样本数量 总样本数量 \text{Accuracy} = \frac{\text{正确分类的样本数量}}{\text{总样本数量}} Accuracy=总样本数量正确分类的样本数量

  • 优点:简单易懂,适合于类别分布相对均衡的场景。
  • 缺点:在类别不平衡的情况下,准确率可能会误导。例如,如果95%的样本是某一类,仅凭准确率,模型可以只预测该类就能达到95%的准确率,但实际上并没有学习到有效的信息。

2. 混淆矩阵(Confusion Matrix)

混淆矩阵是一个非常有用的工具来可视化模型的性能。它展示了实际标签与模型预测标签之间的关系。一个二分类问题的混淆矩阵通常如下所示:

Predicted PositivePredicted Negative
Actual PositiveTrue Positive (TP)False Negative (FN)
Actual NegativeFalse Positive (FP)True Negative (TN)
  • True Positive (TP):正确预测为正样本的数量
  • False Positive (FP):错误预测为正样本的数量(实际为负样本)
  • True Negative (TN):正确预测为负样本的数量
  • False Negative (FN):错误预测为负样本的数量(实际为正样本)

3. 精确率(Precision)、召回率(Recall)与 F1 分数

这些指标特别适用于类别不平衡的情况,以下是它们的定义:

  • 精确率(Precision)

    定义:精确率是指在所有模型预测为正样本的结果中,实际上是真正正样本的比例。这一指标主要关注模型的假正率(False Positive, FP),即将负样本预测为正样本的错误数量。

    衡量模型预测为正的样本中有多少是真正的正样本,公式为:

    Precision = T P T P + F P \text{Precision} = \frac{TP}{TP + FP} Precision=TP+FPTP

  • 召回率(Recall)

    定义:召回率是指在所有实际为正的样本中,模型正确预测为正的比例。召回率主要关注模型的假负率(False Negative, FN),即将正样本预测为负样本的错误数量。

    衡量实际正样本中有多少被模型正确预测为正,公式为:

    Recall = T P T P + F N \text{Recall} = \frac{TP}{TP + FN} Recall=TP+FNTP

  • F1 分数:精确率和召回率的调和平均数,旨在找到两者之间的平衡,公式如下:

    F 1 = 2 × Precision × Recall Precision + Recall F1 = 2 \times \frac{\text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}} F1=2×Precision+RecallPrecision×Recall

作用:

1. 卷积层(Convolutional Layer)

  • 原理:卷积层通过滑动窗口(也称为卷积核或过滤器)在输入图像上进行操作,计算局部区域的加权和。这使得网络能够提取图像中的特征,如边缘、纹理和形状。
  • 作用:通过卷积操作,卷积层能够有效地捕捉空间特征,从而减少参数数量,并且相比全连接层更能利用图像的空间结构。

2. 池化层(Pooling Layer)

  • 原理:池化层通常位于卷积层后面,其通过选择局部区域内的最值或均值来减少特征图的空间维度。最大池化选择一个窗口中的最大值,而平均池化则计算平均值。
  • 作用:降低特征图的维度,减少计算量,增加模型的计算速度,并在某种程度上防止过拟合,通过提炼出关键特征,增强模型的鲁棒性。

3. 全连接层(Dense Layer)

  • 原理:每个神经元与上一层的每个神经元相连接,计算所有输入的加权和并通过激活函数(如 ReLU)进行非线性变换。
  • 作用:全连接层往往是网络的最后一层,负责将所有提取的特征映射到最终的分类标签或回归输出,适用于较小的输入特征集。

4. 批归一化层(Batch Normalization Layer)

  • 原理:对每一批输入进行归一化,使得每个特征的均值接近于 0,方差接近于 1。这是通过计算当前批次的均值和标准差来实现的。
  • 作用:加速训练,稳定模型,提高收敛速度,降低对初始化和学习率的敏感性,通常用在激活函数之前。

5. Dropout层

  • 原理:在每个训练批次中,随机选择一定比例的神经元将其输出设置为零,这样有效地减少了模型的复杂度。
  • 作用:防止模型在训练数据上过拟合,增加模型的泛化能力。

6. 循环层(Recurrent Layer)

  • 原理:通过内部状态(记忆)和序列数据的输入进行连接,允许前一时刻的信息影响当前时刻的输出。LSTM 通过门机制来控制信息流,而 GRU 是 LSTM 的一个简化版本。
  • 作用:处理序列数据,比如时间序列分析和自然语言处理,能够记住长期依赖的信息。

7. 自注意力层(Attention Layer)

  • 原理:计算输入序列中每个元素对其他元素的重要性权重,然后为每个元素生成加权求和的输出。注意力机制帮助模型选择性地关注输入中的重要部分。
  • 作用:在处理序列数据时增强模型的注意力,特别是在自然语言处理和图像任务中提升性能。

8. 嵌入层(Embedding Layer)

  • 原理:将离散的输入(如词汇中的单词)映射到连续的向量空间中,使得相似的输入在向量空间中也相近。嵌入层通常会学习这些向量。
  • 作用:提高文本或离散空间数据的表示能力,将离散数据转化为能够参与深度学习模型的稠密特征向量。

9. 转置卷积层(Transposed Convolution Layer)

  • 原理:通过逆向的卷积操作(也被称为反卷积),在特征图上进行上采样,将低维空间的特征映射到更高维空间。通过在输入中插入零并应用卷积来实现。
  • 作用:用于生成更大输出的特征图,广泛应用于图像生成任务,如生成对抗网络(GAN)中。

10. 残差块(Residual Block)

  • 原理:通过添加跳跃连接,将输入直接加到输出中,让模型更容易学习到恒等映射。这样在训练深层网络时,可以有效缓解梯度消失问题。
  • 作用:在构建深层网络时提高了训练的有效性,减少了复杂度,使得网络可以更高效地学习,能够构建更深的网络而不会遇到性能下降的问题。

http://www.kler.cn/a/392993.html

相关文章:

  • 【安全编码】Web平台如何设计防止重放攻击
  • 深度学习中batch_size
  • 使用Python开发高级游戏:实现一个3D射击游戏
  • 纯Dart Flutter库适配HarmonyOS
  • 安宝特应用 | 美国OSHA扩展Vuzix AR眼镜应用,强化劳动安全与效率
  • 掌握命令行参数的艺术:Python的`argparse`库
  • Matplotlib库中show()函数的用法
  • uniapp内把视频mp4的base64保持到手机文件系统
  • 基于STM32单片机多路无线射频抢答器
  • 算法笔记/USACO Guide GOLD金组Graphs并查集Disjoint Set Union
  • dolphinscheduler
  • Rust编写的贪吃蛇小游戏源代码解读
  • docker pull 网络不通
  • 01.Linux网络设置、FTP
  • 数据驱动的智能决策:民锋科技的量化分析方案
  • golang项目三层依赖架构,自底向上;依赖注入trpc\grpc
  • ES6进阶知识一
  • 【启程Golang之旅】一站式理解Go语言中的gRPC
  • 无人机反制技术与方法:主动防御,被动防御技术原理详解
  • Spring Boot编程训练系统:技术实现与案例分析
  • Linux服务器下oracle自动rman备份的实现
  • 从“大吼”到“轻触”,防爆手机如何改变危险油气环境通信?
  • 【JavaScript】
  • CentOS下如何安装Nginx
  • 音频采样数据格式
  • YOLOv7-0.1部分代码阅读笔记-general.py