pytorch常用参数初始化
一、基础初始化方法
1. 全零初始化(Zero Initialization)
• 方法:权重初始化为0(不推荐用于隐藏层)
• 问题:导致所有神经元对称更新,失去多样性
• PyTorch代码:
python nn.init.zeros_(weight)
2. 随机初始化(Random Initialization)
• 均匀分布:U(-a, a)
python nn.init.uniform_(weight, a=-0.1, b=0.1)
• 正态分布:N(0, std)
python nn.init.normal_(weight, mean=0.0, std=0.01)
二、经典初始化方法
3. Xavier/Glorot 初始化
• 核心思想:保持输入输出方差一致,适用于tanh/sigmoid激活
• 数学公式:
std
=
2
n
i
n
+
n
o
u
t
(
tanh
)
\text{std} = \sqrt{\frac{2}{n_{in} + n_{out}}} \quad (\text{tanh})
std=nin+nout2(tanh)
• PyTorch代码:
```python
# Xavier均匀分布(默认)
nn.init.xavier_uniform_(weight, gain=nn.init.calculate_gain('tanh'))
# Xavier正态分布
nn.init.xavier_normal_(weight, gain=1.0)
```
4. He/Kaiming 初始化
• 核心思想:修正ReLU族的负区间影响,适用于ReLU/LeakyReLU
• 数学公式:
std
=
2
n
i
n
(
ReLU
)
\text{std} = \sqrt{\frac{2}{n_{in}}} \quad (\text{ReLU})
std=nin2(ReLU)
• PyTorch代码:
```python
# Kaiming均匀分布(推荐默认)
nn.init.kaiming_uniform_(weight, mode='fan_in', nonlinearity='relu')
# Kaiming正态分布
nn.init.kaiming_normal_(weight, mode='fan_out', nonlinearity='leaky_relu', a=0.1)
```
三、高级初始化方法
5. 正交初始化(Orthogonal Initialization)
• 核心思想:保持输入输出空间的正交性,适用于RNN/LSTM
• 数学原理:权重矩阵满足 ( W^T W = I )
• PyTorch代码:
python nn.init.orthogonal_(weight, gain=1.0)
6. 稀疏初始化(Sparse Initialization)
• 方法:随机将部分权重设为0,打破对称性
• PyTorch代码:
python nn.init.sparse_(weight, sparsity=0.1, std=0.01)
7. 截断正态分布(Truncated Normal)
• 特点:限制采样范围在±2std内,避免极端值
• PyTorch代码:
python nn.init.trunc_normal_(weight, mean=0.0, std=0.02, a=-0.04, b=0.04)
四、特殊场景初始化
8. 残差网络初始化
• 核心技巧:将最后一个BN层的权重初始化为0
python nn.init.constant_(bn.weight, 0) # 保证初始残差块恒等映射
9. 自注意力初始化
• Query/Key初始化:缩小点积范围
python nn.init.xavier_uniform_(qkv_weight, gain=1/math.sqrt(dim))
10. 预训练模型加载
• HuggingFace最佳实践:
python model = BertModel.from_pretrained('bert-base-uncased') model.init_weights() # 对新增层初始化
五、初始化方法对比表
方法 | 核心思想 | 适用场景 | PyTorch函数 |
---|---|---|---|
Xavier | 输入输出方差一致 | tanh/sigmoid | xavier_uniform_ |
He/Kaiming | 修正ReLU负区间 | ReLU/LeakyReLU | kaiming_normal_ |
正交初始化 | 保持矩阵正交性 | RNN/LSTM | orthogonal_ |
截断正态 | 避免极端值 | Transformer/ViT | trunc_normal_ |
稀疏初始化 | 打破对称性 | 大规模稀疏网络 | sparse_ |
六、最佳实践建议
- 默认首选:对CNN/MLP使用
Kaiming初始化
,对RNN使用正交初始化
- 激活函数适配:根据激活函数选择
nonlinearity
参数 - 模式选择:
•fan_in
:推荐前向传播时保持方差
•fan_out
:适合反向传播时保持梯度 - 组合使用:初始化后配合BatchNorm使用效果更佳
- 可视化验证:使用直方图观察参数分布是否符合预期
代码示例:综合初始化
def init_weights(m):
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
nn.init.xavier_uniform_(m.weight)
elif isinstance(m, nn.LSTM):
for name, param in m.named_parameters():
if 'weight' in name:
nn.init.orthogonal_(param)
elif 'bias' in name:
nn.init.constant_(param, 0)
model.apply(init_weights)