前馈神经网络 - 参数学习(梯度下降法 - 多分类任务)
之前的博文中,对于前馈神经网络,我们学习过基于二分类任务的参数学习,本文我们来学习两个多分类任务的参数学习的例子,来进一步加深对反向传播算法的理解。
例子1:前馈神经网络在多分类任务中的参数学习(梯度下降法)
以下通过一个 3分类任务 的具体例子,详细说明前馈神经网络如何使用梯度下降法进行参数学习。假设网络结构如下:
-
输入层:2个神经元(输入特征 x1,x2)
-
隐藏层:3个神经元(使用 ReLU 激活函数)
-
输出层:3个神经元(使用 Softmax 激活函数)
1. 网络参数初始化
假设输入样本为 x=[1,0],真实标签为类别 2(标签编码为 one-hot 向量 y=[0,0,1])。
参数定义:
-
输入层到隐藏层:
-
隐藏层到输出层:
2. 前向传播
-
隐藏层计算:
-
线性变换:
-
ReLU 激活:
-
-
输出层计算:
-
线性变换:
-
Softmax 激活(转换为概率):
-
3. 损失计算(交叉熵损失)
真实标签为类别 2(y=[0,0,1]),预测概率为 y^≈[0.31,0.20,0.49]:
4. 反向传播计算梯度
输出层梯度
-
Softmax + 交叉熵的梯度简化:
-
参数梯度:
隐藏层梯度
-
误差信号传播:
-
ReLU 导数:
-
上游误差:
-
逐元素相乘:
-
-
参数梯度:
5. 参数更新(学习率 η=0.1)
-
输出层参数:
-
隐藏层参数:
6. 验证更新后的预测
更新参数后,重新进行前向传播:
-
隐藏层输出可能更接近真实类别 2,损失 L 应减小。例如,若新的预测概率为 [0.25,0.15,0.60],则损失为 −log(0.60)≈0.511<0.713,表明参数学习有效。
关键总结
-
Softmax + 交叉熵:
-
多分类任务的标准组合,梯度计算简化为
。
-
-
ReLU 导数特性:
-
激活导数为 0 或 1,加速计算并缓解梯度消失问题。
-
-
梯度下降步骤:
-
通过链式法则逐层计算梯度,参数沿负梯度方向更新。
-
-
实际应用注意点:
-
学习率需调参(过大震荡,过小收敛慢)。
-
参数初始化影响收敛(如 Xavier 初始化)。
-
例子2:一个简单的多层感知器(MLP)
下面给出一个基于多分类任务的前馈神经网络参数学习过程,展示如何使用梯度下降法(GD)结合反向传播计算梯度,逐步优化参数。我们以一个简单的多层感知器(MLP)来处理三分类问题为例。
1. 网络结构设定
假设我们的任务是将输入样本分为三个类别(类别1、类别2、类别3)。网络结构如下:
- 输入层:假设输入向量 x∈R^d(例如 d=4)。
- 隐藏层:设有一层隐藏层,包含 h 个神经元,激活函数使用 ReLU。
- 输出层:有 3 个神经元,对应三个类别,激活函数采用 Softmax,将输出转换为概率分布。
具体数学描述:
- 隐藏层:
- 输出层:
Softmax 的定义为:
2. 损失函数
对于多分类任务,我们通常采用多类别交叉熵损失函数。假设真实标签 y 使用 one-hot 编码,交叉熵损失为:
3. 具体例子
网络参数设定(示例数值)
假设输入维度 d=4,隐藏层神经元数量 h=3:
- 输入 x = [1.0, 0.5, -1.0, 2.0]^T。
- 隐藏层权重:
隐藏层偏置:
- 输出层权重:
输出层偏置:
假设真实标签为类别3,即 one-hot 编码 y = [0, 0, 1]^T。
前向传播计算
隐藏层计算:
- 计算
逐个神经元计算:
-
神经元1:
-
神经元2:
-
神经元3:
得到 .
2. 通过 ReLU 激活函数计算 a^{(1)}:
输出层计算:
- 计算
对每个类别计算:
-
类别1:
-
类别2:
-
类别3:
2. 通过 Softmax 激活函数计算预测概率:
此时,模型预测概率为:
- 类别1:42.1%,
- 类别2:27.7%,
- 类别3:30.2%。
假设真实标签为类别3,则 one-hot 编码 y = [0,0,1]^T。
4. 损失计算
采用多类别交叉熵损失函数:
由于 y=[0,0,1] ,损失为:
5. 反向传播与参数更新(简要描述)
-
输出层梯度:
-
隐藏层梯度:
-
参数更新: 使用梯度下降法(例如学习率 η),更新各层参数:
经过多次迭代和大量样本训练,网络参数逐渐调整使得损失函数最小化,模型预测准确率不断提升。
总结
利用梯度下降法对前馈神经网络进行参数学习的过程包括:
- 前向传播:将输入数据通过网络各层计算,得到预测概率。
- 损失计算:利用多类别交叉熵损失函数衡量预测与真实标签之间的差距。
- 反向传播:使用链式法则,从输出层到隐藏层逐层计算梯度。
- 参数更新:依据计算得到的梯度,采用梯度下降(或其变种)更新各层权重和偏置。
通过具体的多分类任务示例(例如一个三类别分类问题),我们可以看到如何从输入、前向传播、损失计算、反向传播到参数更新的整个流程,最终实现神经网络参数的优化和任务性能的提升。