当前位置: 首页 > article >正文

神经网络初始化 (init) 介绍

文章目录

    • 引言
    • 1. 初始化的重要性
      • 1.1 打破对称性
      • 1.2 控制方差
      • 1.3 加速收敛与提高泛化能力
    • 2. 常见的初始化方法及其应用场景
      • 2.1 Xavier/Glorot 初始化
      • 2.2 He 初始化
      • 2.3 正交初始化
      • 2.4 其他初始化方法
    • 3. 如何设置初始化
    • 4. 基于 BERT 的文本分类如何进行初始化
      • 4.1 项目背景
      • 4.2 模型构建
      • 4.3 模型训练与评估
      • 4.4 结果分析
    • 结论
    • 参考资料

引言

在深度学习的世界中,构建一个高效且性能优异的神经网络模型需要综合考虑多个因素。尽管选择合适的架构和优化算法至关重要,但权重初始化这一环节同样不容忽视。合适的初始化策略不仅能加速模型的收敛速度,提升训练稳定性,还能显著影响最终的模型性能。

1. 初始化的重要性

在深入探讨各种初始化方法之前,首先需要理解权重初始化在神经网络训练中的关键作用。

1.1 打破对称性

如果所有神经元的权重初始化为相同的值,网络在训练初期将无法学习到多样化的特征。这是因为在前向传播和反向传播过程中,每个神经元都会计算出相同的输出和梯度,导致它们在训练过程中同步更新,学习到相同的内容。打破对称性通过确保每个神经元的初始权重具有一定的随机性,使得每个神经元能够独立地探索不同的特征空间,从而提高模型的表达能力。

1.2 控制方差

在深层网络中,信号通过多层传播可能会逐渐放大或缩小,导致梯度消失或爆炸。这些问题会严重影响模型的训练效果,尤其是在反向传播阶段。合理的权重初始化能够保持每一层输出的方差稳定,确保信号在整个网络中均匀传播,避免梯度消失或爆炸,从而促进稳定的训练过程。

1.3 加速收敛与提高泛化能力

正确的初始化策略能够引导损失函数的优化过程,使其更容易找到好的局部最小值或全局最小值。这不仅能加快训练速度,还能提升模型的最终性能。此外,合理的初始化还能够帮助模型更快地进入一个具备良好泛化能力的参数区域,提升其在未见数据上的表现。

2. 常见的初始化方法及其应用场景

根据不同的激活函数和网络架构,存在多种权重初始化方法。以下是几种常见且有效的初始化策略及其适用场景。

2.1 Xavier/Glorot 初始化

适用场景:适用于激活函数为 Sigmoid 或 Tanh 的神经网络。

原理:Xavier 初始化通过维持每一层输入和输出信号的方差一致,防止梯度在传播过程中逐渐消失或爆炸。具体来说,它根据输入和输出神经元的数量来设定权重的初始化范围,通常采用均匀分布或正态分布。

示例

import torch.nn as nn

# Xavier 初始化示例
linear = nn.Linear(in_features=256, out_features=128)
nn.init.xavier_uniform_(linear.weight)

2.2 He 初始化

适用场景:专为 ReLU 及其变体设计的神经网络。

原理:He 初始化考虑到 ReLU 激活函数的非负输出特性,调整了初始化时权重的方差,使其更适合 ReLU 的特性。这样可以有效地保持信号在前向传播过程中的标准差不变,避免梯度消失。

示例

import torch.nn as nn

# He 初始化示例
conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3)
nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

2.3 正交初始化

适用场景:特别适合循环神经网络(RNNs)和非常深的前馈网络。

原理:正交初始化通过确保权重矩阵的列彼此正交且具有单位长度,有效地防止了梯度消失或爆炸的问题。对于 RNN 来说,正交矩阵能够保持序列数据的长时间依赖关系,提升模型的表现。

示例

import torch.nn as nn

# 正交初始化示例
linear = nn.Linear(in_features=256, out_features=256)
nn.init.orthogonal_(linear.weight)

2.4 其他初始化方法

除了上述主要方法外,还有一些其他初始化策略,尽管在现代实践中使用较少,但在特定场景下也有其应用价值。

  • 零初始化:将权重初始化为零。这种方法通常不推荐用于隐藏层,因为会导致对称性问题,但可以用于初始化偏置项。

    import torch.nn as nn
    
    # 零初始化示例
    linear = nn.Linear(in_features=256, out_features=128)
    nn.init.constant_(linear.weight, 0)
    
  • 随机初始化:使用标准正态分布或均匀分布随机初始化权重。虽然简单,但在深层网络中可能导致梯度问题。

    import torch.nn as nn
    
    # 随机初始化示例
    linear = nn.Linear(in_features=256, out_features=128)
    nn.init.normal_(linear.weight, mean=0.0, std=0.02)
    
  • 稀疏初始化:初始化部分权重为非零,其他为零。适用于希望网络具有稀疏连接的场景。

    import torch.nn as nn
    
    # 稀疏初始化示例
    linear = nn.Linear(in_features=256, out_features=128)
    nn.init.sparse_(linear.weight, sparsity=0.1)
    

3. 如何设置初始化

在深度学习框架如 PyTorch 中,权重初始化可以通过自定义初始化方法或直接利用内置函数来实现。以下以一个简单的卷积神经网络(CNN)为例,展示如何在构造函数中应用不同的初始化策略。

import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()

        # 定义网络层
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(32 * 16 * 16, num_classes)  # 假设输入图像大小为32x32

        # 初始化权重
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # 使用 He 初始化
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # 批归一化层权重初始化为1,偏置初始化为0
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                # 使用 Xavier 初始化
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

解析

  • 卷积层(Conv2d):采用 He 初始化,适合 ReLU 激活函数,确保信号在前向传播中的稳定。

  • 批归一化层(BatchNorm2d):权重初始化为1,偏置初始化为0,保证初始状态下批归一化层的标准化效果。

  • 全连接层(Linear):采用 Xavier 初始化,适合 Sigmoid 或 Tanh 激活函数,保持输入和输出信号的方差一致。

4. 基于 BERT 的文本分类如何进行初始化

为了更深入地理解权重初始化的实际应用,本文将通过一个具体的文本分类任务,展示如何在预训练模型 BERT 的基础上进行初始化和微调。

4.1 项目背景

文本分类是自然语言处理中的基础任务之一,广泛应用于情感分析、垃圾邮件检测、话题分类等场景。近年来,预训练语言模型如 BERT(Bidirectional Encoder Representations from Transformers)因其强大的语言理解能力,成为文本分类任务的首选基础模型。

4.2 模型构建

以下代码展示了如何构建一个基于 BERT 的文本分类器,并在新增加的分类层上应用权重初始化。

from transformers import BertModel, BertTokenizer
import torch.nn as nn

class BertForTextClassification(nn.Module):
    def __init__(self, num_labels=2, dropout_rate=0.3):
        super(BertForTextClassification, self).__init__()
        # 加载预训练的 BERT 模型
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')        

        # 冻结 BERT 模型的参数以加快训练速度(可选)
        for param in self.bert.parameters():
            param.requires_grad = False       

        # 定义分类头
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)        

        # 初始化分类层的权重
        self._initialize_weights()

    def _initialize_weights(self):
        # 使用 Xavier 初始化分类层权重
        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None:
            nn.init.zeros_(self.classifier.bias)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        # 获取 BERT 的输出
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        # 取 [CLS] 标记的输出作为分类依据
        cls_output = outputs.last_hidden_state[:, 0, :]
        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)

        return logits

关键步骤解析

  1. 加载预训练模型

    self.bert = BertModel.from_pretrained('bert-base-uncased')
    self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    

    通过 transformers 库加载预训练的 BERT 模型及其对应的分词器。

  2. 冻结预训练模型参数(可选)

    for param in self.bert.parameters():
        param.requires_grad = False
    

    冻结 BERT 模型的参数,仅训练新增加的分类层,能够显著减少训练时间和计算资源消耗,适用于数据量较小的场景。

  3. 定义分类头

    self.dropout = nn.Dropout(dropout_rate)
    self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
    

    使用 dropout 层防止过拟合,并通过全连接层将 BERT 的输出映射到分类标签空间。

  4. 初始化分类层权重

    self._initialize_weights()
    

    为新增加的分类层应用 Xavier 初始化,确保其在训练开始时具有良好的表现。

4.3 模型训练与评估

在训练过程中,合理的初始化策略能够帮助模型更快地收敛,并在有限的训练迭代中达到较好的性能。以下是一个简单的训练和评估流程示例:

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW
from sklearn.metrics import accuracy_score

# 假设已定义好文本数据集和数据加载器
class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length    

    def __len__(self):
        return len(self.texts)    

    def __getitem__(self, idx):
        encoding = self.tokenizer.encode_plus(
            self.texts[idx],
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }

# 初始化数据集和数据加载器
train_dataset = TextDataset(train_texts, train_labels, model.tokenizer)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataset = TextDataset(val_texts, val_labels, model.tokenizer)
val_loader = DataLoader(val_dataset, batch_size=32)

# 初始化模型、优化器和损失函数
model = BertForTextClassification(num_labels=2)
optimizer = AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

# 训练循环
for epoch in range(3):  # 假设训练3个epoch
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()    

    # 评估
    model.eval()
    all_preds, all_labels = [], []
    
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            labels = batch['labels']
            
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
    acc = accuracy_score(all_labels, all_preds)
    print(f"Epoch {epoch + 1} - Validation Accuracy: {acc:.4f}")

关键点

  • 数据预处理:使用 BERT 的分词器将文本转换为模型可接受的输入格式,包括 input_idsattention_mask

  • 冻结与微调:根据具体需求,可以选择冻结 BERT 的部分或全部参数,仅训练新增加的层,或进行全模型微调。

  • 优化器与损失函数:使用 AdamW 优化器和交叉熵损失函数,适用于分类任务。

4.4 结果分析

通过合理的权重初始化和预训练模型的优势,基于 BERT 的文本分类器在多个标准数据集上表现出色。例如,在情感分析任务中,冻结 BERT 参数并仅训练分类层的模型,能够在较短的训练时间内达到接近全模型微调的性能,同时显著减少计算资源的消耗。

结论

权重初始化在神经网络训练中扮演着至关重要的角色。合适的初始化策略不仅能够打破对称性,控制信号的方差,还能加速模型的收敛,提高泛化能力。本文系统地介绍了几种常见的初始化方法及其适用场景,并通过基于 BERT 的文本分类示例,展示了如何在实际项目中应用这些初始化策略。

参考资料

  • Glorot, X., & Bengio, Y. (2010). Understanding the difficulty of training deep feedforward neural networks.

  • He, K., Zhang, X., Ren, S., & Sun, J. (2015). Delving deep into rectifiers: Surpassing human-level performance on imagenet classification.

  • BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding

  • PyTorch 官方文档 - 权重初始化


http://www.kler.cn/a/503697.html

相关文章:

  • 抓包之使用抓包来验证TCP三次握手
  • Anaconda安装(2024最新版)
  • Android SystemUI——基础简介(一)
  • 源码编译安装httpd 2.4,提供系统服务管理脚本并测试(两种方法实现)
  • thinkphp 5.0 结合redis 做延迟队列,队列无法被消费
  • MIUI显示/隐藏5G开关的方法,信号弱时开启手机Wifi通话方法
  • 你喜欢用什么编辑器?
  • Java中的动态代理是什么?如何实现?
  • Spark vs Flink分布式数据处理框架的全面对比与应用场景解析
  • P1图文解析:初识算法和数据结构
  • [Deep Learning] Anaconda+CUDA+CuDNN+Pytorch(GPU)环境配置-2025
  • day09_kafka高级
  • 《蜜蜂路线》
  • JavaWeb开发 - Filter过滤器详解
  • 【OJ刷题】同向双指针问题3
  • 【机器学习】P1 机器学习绪论
  • vite之---为什么选vite
  • seleniun 自动化程序,python编程 我监控 chrome debug数据后 ,怎么获取控制台的信息呢
  • 使用Docker模拟PX4固件的无人机用于辅助地面站开发
  • MCP Server开发的入门教程(python和pip)
  • 服务器宕机原因?该怎么处理?
  • AOSP 14及以上userdebug无法调试的问题
  • wireshark开启对https密文抓包
  • c语言注意事项(不断完善版)
  • 大师课程:专业角色AE+AI动画动态设计关键帧学院视频课程 Key Frame Academy – Character Animation Launchpad
  • Java全栈项目-学生创新创业项目管理系统