当前位置: 首页 > article >正文

pytorch常用参数初始化


一、基础初始化方法

1. 全零初始化(Zero Initialization)

方法:权重初始化为0(不推荐用于隐藏层)
问题:导致所有神经元对称更新,失去多样性
PyTorch代码
python nn.init.zeros_(weight)

2. 随机初始化(Random Initialization)

均匀分布U(-a, a)
python nn.init.uniform_(weight, a=-0.1, b=0.1)
正态分布N(0, std)
python nn.init.normal_(weight, mean=0.0, std=0.01)


二、经典初始化方法

3. Xavier/Glorot 初始化

核心思想:保持输入输出方差一致,适用于tanh/sigmoid激活
数学公式
std = 2 n i n + n o u t ( tanh ) \text{std} = \sqrt{\frac{2}{n_{in} + n_{out}}} \quad (\text{tanh}) std=nin+nout2 (tanh)

PyTorch代码

 ```python
 # Xavier均匀分布(默认)
 nn.init.xavier_uniform_(weight, gain=nn.init.calculate_gain('tanh'))
 
 # Xavier正态分布
 nn.init.xavier_normal_(weight, gain=1.0)
 ```
4. He/Kaiming 初始化

核心思想:修正ReLU族的负区间影响,适用于ReLU/LeakyReLU
数学公式
std = 2 n i n ( ReLU ) \text{std} = \sqrt{\frac{2}{n_{in}}} \quad (\text{ReLU}) std=nin2 (ReLU)

PyTorch代码

 ```python
 # Kaiming均匀分布(推荐默认)
 nn.init.kaiming_uniform_(weight, mode='fan_in', nonlinearity='relu')
 
 # Kaiming正态分布
 nn.init.kaiming_normal_(weight, mode='fan_out', nonlinearity='leaky_relu', a=0.1)
 ```

三、高级初始化方法

5. 正交初始化(Orthogonal Initialization)

核心思想:保持输入输出空间的正交性,适用于RNN/LSTM
数学原理:权重矩阵满足 ( W^T W = I )
PyTorch代码
python nn.init.orthogonal_(weight, gain=1.0)

6. 稀疏初始化(Sparse Initialization)

方法:随机将部分权重设为0,打破对称性
PyTorch代码
python nn.init.sparse_(weight, sparsity=0.1, std=0.01)

7. 截断正态分布(Truncated Normal)

特点:限制采样范围在±2std内,避免极端值
PyTorch代码
python nn.init.trunc_normal_(weight, mean=0.0, std=0.02, a=-0.04, b=0.04)


四、特殊场景初始化

8. 残差网络初始化

核心技巧:将最后一个BN层的权重初始化为0
python nn.init.constant_(bn.weight, 0) # 保证初始残差块恒等映射

9. 自注意力初始化

Query/Key初始化:缩小点积范围
python nn.init.xavier_uniform_(qkv_weight, gain=1/math.sqrt(dim))

10. 预训练模型加载

HuggingFace最佳实践
python model = BertModel.from_pretrained('bert-base-uncased') model.init_weights() # 对新增层初始化


五、初始化方法对比表

方法核心思想适用场景PyTorch函数
Xavier输入输出方差一致tanh/sigmoidxavier_uniform_
He/Kaiming修正ReLU负区间ReLU/LeakyReLUkaiming_normal_
正交初始化保持矩阵正交性RNN/LSTMorthogonal_
截断正态避免极端值Transformer/ViTtrunc_normal_
稀疏初始化打破对称性大规模稀疏网络sparse_

六、最佳实践建议

  1. 默认首选:对CNN/MLP使用Kaiming初始化,对RNN使用正交初始化
  2. 激活函数适配:根据激活函数选择nonlinearity参数
  3. 模式选择
    fan_in:推荐前向传播时保持方差
    fan_out:适合反向传播时保持梯度
  4. 组合使用:初始化后配合BatchNorm使用效果更佳
  5. 可视化验证:使用直方图观察参数分布是否符合预期

代码示例:综合初始化

def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)
    elif isinstance(m, nn.LSTM):
        for name, param in m.named_parameters():
            if 'weight' in name:
                nn.init.orthogonal_(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0)

model.apply(init_weights)

在这里插入图片描述


http://www.kler.cn/a/577800.html

相关文章:

  • es优化方面
  • LeetCode1137 第N个泰波那契数
  • C++入门——函数重载
  • linux 命令sed
  • QT中使用C++调用 python脚本
  • 驱动开发系列45 - Linux 显卡KMD驱动代码分析(六)- 显卡驱动与OS接口
  • 小迪安全-27-php开发,tp框架,路由访问,对象操作,内置过滤,核心漏洞
  • 2.2.1 网络原理-posix api
  • #9 【code】实现扩散模型的一个jupyter notebook
  • PX4中的DroneCAN的实现库Libuavcan及基础功能示例
  • 【Hadoop】什么是Zookeeper?如何理解Zookeeper?
  • 记录小白使用 Cursor 开发第一个微信小程序(一):注册账号及下载工具(250308)
  • Dubbo+Zookeeper
  • 从零开始用react + tailwindcss + express + mongodb实现一个聊天程序(十一) 实现服务端和客户端socketio 连接
  • 金融合规测试:金融系统稳健运行的“定海神针“
  • 物联网通过数字孪生技术实现设备状态的实时仿真和优化
  • 每日一练之移除链表元素
  • spring IOC(实现原理)
  • 基于自定义Tomcat实现资源访问的完整指南
  • 探索React:构建现代前端应用的强大框架