当前位置: 首页 > article >正文

Pytorch基本使用—参数初始化

深度学习模型参数初始化是指在训练深度神经网络时,对网络的权重和偏置进行初始化的过程。合适的参数初始化可以加速模型的收敛,并提高模型的性能。

✨ 1 基本介绍

在深度学习中,常用的参数初始化方法有以下几种:

  1. 零初始化(Zero Initialization):将所有权重和偏置初始化为0。然而,这种方法会导致所有神经元具有相同的输出,无法破坏对称性,因此不常用。
  2. 随机初始化(Random Initialization):将权重和偏置随机初始化为较小的随机值。这种方法可以打破对称性,但并不能保证初始化的权重和偏置能够适应网络的输入和输出分布。
  3. Xavier初始化(Xavier Initialization):根据每一层的输入维度和输出维度的大小来进行初始化。Xavier初始化方法根据激活函数的导数和输入输出的维度来调整初始化的范围,使得每一层的激活值保持在一个合适的范围内。
  4. He初始化(He Initialization):类似于Xavier初始化,但在计算权重的标准差时,将输入维度除以2。这是由于ReLU等非线性激活函数的性质导致的。
  5. 预训练初始化(Pretraining Initialization):在某些情况下,可以使用预训练的模型参数来初始化新的模型。例如,利用在大规模数据集上预训练的模型参数来初始化新任务的模型,可以加快模型的收敛速度。

需要注意的是,不同的参数初始化方法适用于不同的网络架构和激活函数。在实际应用中,需要根据具体情况选择适当的参数初始化方法。此外,还可以通过调整学习率和正则化等技巧来进一步优化训练过程。

✨ 2 零初始化(不常用)

🎈 2.1 理论

这里主要分析一下神经网络为什么不能将参数全部初始化为0
假设我们有下面的网络(为了简单,全部以线性函数计算):

第一层计算为:
在这里插入图片描述
第二层计算为:
在这里插入图片描述
以参数W11和W12的反向传播为例,梯度为:
在这里插入图片描述
因为都是0,则梯度为0,则参数更新停止。

✨ 3 Xavier初始化

第二节我们简单总结了为什么神经网络参数不能输出化为0,接下来我们讨论Xavier初始化。

🎃 3.1 介绍

在神经网络中,每个神经元的输入是由上一层的神经元输出和权重参数决定的。如果权重参数初始化过大,会导致输入值变得很大,从而使得激活函数的导数趋近于0,造成梯度消失问题。相反,如果权重参数初始化过小,会导致输入值变得很小,从而使得激活函数的导数趋近于1,造成梯度爆炸问题。

Xavier初始化通过根据网络层的输入和输出维度来合理地初始化权重参数,使得权重参数的方差保持在一个相对稳定的范围内。这样可以避免梯度消失和梯度爆炸问题,有助于提高网络的训练效果。

⛱️ 3.2 推导

这里我们以下列网络为例:
在这里插入图片描述

首先看前向传播
在这里插入图片描述
方差为(这里应用概率论相关计算公式,需要注意的是这里Xi经过归一化,E(Xi)=0)
在这里插入图片描述
如果Xi和Wi独立同分布,那么D(a1)的最终公式为
在这里插入图片描述
这里在这里插入图片描述代表着输入维度
而我们的目标是在这里插入图片描述,因此在这里插入图片描述


与上述计算方式一样,反向传播最终结果是在这里插入图片描述。只是这里,在这里插入图片描述是输出的维度大小。


但是在这里插入图片描述在这里插入图片描述一般情况下是不同的,因此,这里采取一种折中的方式
在这里插入图片描述,我们让在这里插入图片描述在区间[a, b]上均匀采样(均匀分布)
结合均匀分布方差公式在这里插入图片描述,解出Xavier初始化采样范围为在这里插入图片描述

🎈 3.3 构造

torch.nn.init.xavier_uniform_(tensor, a=0, b=1)
  1. tensor:需要填充的张量
  2. a:均匀分布的下界
  3. b:均匀分布的上界

☃️ 3.4 例子

w = torch.empty(3, 5)
nn.init.uniform_(w)

result:

tensor([[0.2116, 0.3085, 0.5448, 0.6113, 0.7697],
        [0.8300, 0.2938, 0.4597, 0.4698, 0.0624],
        [0.5034, 0.1166, 0.3133, 0.3615, 0.3757]])

http://www.kler.cn/a/37701.html

相关文章:

  • Spring Boot使用DataFreezer操作Aerospike
  • 确定Linux虚拟机需要安装哪个架构的应用
  • mongdb安全认证详解
  • Jenkins Pipline使用SonarScanner 检查 VUE、js 项目 中遇到的Bug
  • 渲染流程(上):HTML、CSS和JavaScript,是如何变成页面的?
  • Flutter:EasyLoading(loading加载、消息提示)
  • MySQL-分库分表详解(三)
  • 三菱PLC上位机测试
  • 100种思维模型之安全边际思维模型-92
  • 第五十八章 开发Productions - ObjectScript Productions - 测试和调试Production
  • 《计算机网络--自顶向下方法》第四章--网络层:数据平面
  • c# GDI+绘图的应用-多边形
  • 【C++刷题集】-- day4
  • 分布式锁与同步锁
  • MySQL---表数据高效率查询(简述)
  • C++中随机数的使用总结
  • C国演义 [第九章]
  • 拖动排序功能的实现 - 使用HTML、CSS和JavaScript
  • unbuntu 22.04 安装和卸载企业微信
  • XPath 文本匹配:正则表达式的应用与技巧