当前位置: 首页 > article >正文

【HuggingFace Transformers】BertSelfOutput 和 BertOutput源码解析

BertSelfOutput 和 BertOutput源码解析

  • 1. 介绍
    • 1.1 共同点
      • (1) 残差连接 (Residual Connection)
      • (2) 层归一化 (Layer Normalization)
      • (3) Dropout
      • (4) 线性变换 (Linear Transformation)
    • 1.2 不同点
      • (1) 处理的输入类型
      • (2) 线性变换的作用
      • (3) 输入的特征大小
  • 2. 源码解析
    • 2.1 BertSelfOutput 源码解析
    • 2.2 BertOutput 源码解析

1. 介绍

BertSelfOutputBertOutputBERT 模型中两个相关但不同的模块。它们在功能上有许多共同点,但也有一些关键的不同点。以下通过共同点和不同点来介绍它们。

1.1 共同点

BertSelfOutputBertOutput 都包含残差连接、层归一化、Dropout 和线性变换,并且这些操作的顺序相似。

(1) 残差连接 (Residual Connection)

两个模块都应用了残差连接,即将模块的输入直接与经过线性变换后的输出相加。这种结构可以帮助缓解深层神经网络中的梯度消失问题,使信息更直接地传递,保持梯度流动顺畅。

(2) 层归一化 (Layer Normalization)

在应用残差连接后,两个模块都使用层归一化 (LayerNorm) 来规范化输出。这有助于加速训练,稳定网络性能,并减少内部分布变化的问题。

(3) Dropout

两个模块都包含一个 Dropout 层,用于随机屏蔽一部分神经元的输出,增强模型的泛化能力,防止过拟合。

(4) 线性变换 (Linear Transformation)

两个模块都包含一个线性变换 (dense 层)。这个线性变换用于调整数据的维度,并为后续的残差连接和层归一化做准备。

1.2 不同点

BertSelfOutput 专注于处理自注意力机制的输出,而 BertOutput 则处理前馈神经网络的输出。它们的输入特征维度也有所不同,线性变换的作用在两个模块中也略有差异。

(1) 处理的输入类型

  • BertSelfOutput:处理自注意力机制 (BertSelfAttention) 的输出。它关注的是如何将注意力机制生成的特征向量与原始输入结合起来。
  • BertOutput:处理的是前馈神经网络的输出。它将经过注意力机制处理后的特征进一步加工,并整合到当前层的最终输出中。

(2) 线性变换的作用

  • BertSelfOutput:线性变换的作用是对自注意力机制的输出进行进一步的变换和投影,使其适应后续的处理流程。
  • BertOutput:线性变换的作用是对前馈神经网络的输出进行变换,使其与前一层的输出相结合,并准备传递到下一层。

(3) 输入的特征大小

  • BertSelfOutput:输入和输出的特征维度保持一致,都是 BERT 模型的隐藏层大小 (hidden_size)。
  • BertOutput:输入的特征维度是中间层大小 (intermediate_size),输出则是 BERT 模型的隐藏层大小 (hidden_size)。这意味着 BertOutput 的线性变换需要将中间层的维度转换回隐藏层的维度。

2. 源码解析

源码地址:transformers/src/transformers/models/bert/modeling_bert.py

2.1 BertSelfOutput 源码解析

# -*- coding: utf-8 -*-
# @time: 2024/7/15 14:27

import torch
from torch import nn


class BertSelfOutput(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)  # 定义线性变换层,将自注意力输出映射到 hidden_size 维度
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # 层归一化
        self.dropout = nn.Dropout(config.hidden_dropout_prob)  # Dropout层

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)  # 对自注意力机制的输出进行线性变换
        hidden_states = self.dropout(hidden_states)  # Dropout操作
        hidden_states = self.LayerNorm(hidden_states + input_tensor)  # 残差连接后进行层归一化
        return hidden_states

2.2 BertOutput 源码解析

# -*- coding: utf-8 -*-
# @time: 2024/8/22 15:41

import torch
from torch import nn


class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)  # 定义线性变换层,将前馈神经网络输出从 intermediate_size 映射到 hidden_size
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # 层归一化
        self.dropout = nn.Dropout(config.hidden_dropout_prob)  # Dropout层

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)  # 对前馈神经网络的输出进行线性变换
        hidden_states = self.dropout(hidden_states)  # Dropout操作
        hidden_states = self.LayerNorm(hidden_states + input_tensor)  # 残差连接后进行层归一化
        return hidden_states

http://www.kler.cn/news/283837.html

相关文章:

  • 开源个人云存储管理专家:Cloudreve
  • 零基础入门转录组数据分析——单基因ROC分析
  • Leetcode Java学习记录——动态规划基础_3
  • 尚硅谷大数据技术-Kafka视频教程-笔记01【Kafka 入门】
  • 8月30复盘日记
  • k8s-pod 实战四 什么是 Kubernetes Pod?如何在生产环境中使用它?(学习专场,实战就看这一篇就够了)
  • 把http网站变成https
  • WPF 使用PdfiumViewer实现PDF预览与打印
  • RabbitMQ本地Ubuntu系统环境部署与无公网IP远程连接服务端实战演示
  • element input限制输入框只能输入数字
  • 深入解析:文本分析模型性能评估的艺术与科学
  • 浅谈对分布式锁的认识
  • React中实现antd自定义图标,鼠标悬浮变色
  • Java算法之BogoSort(或称为Permutation Sort、Monkey Sort)
  • day39(了解docker-compose,docker-compose编排容器,配置harbor服务)
  • PneumoLLM: 利用大语言模型的力量进行尘肺病诊断| 文献速递-大模型与多模态诊断阿尔茨海默症与帕金森疾病应用
  • 数据的时光机:SQL中实现数据版本控制的策略
  • Go微服务开发框架DMicro的设计思路
  • Scala之高阶面向对象编程
  • 【NCom】:通用负压退火方法构建超高负载单原子催化剂库
  • Python 3.11 从入门到实战1(环境准备)
  • 鸿蒙XComponent组件的认识
  • FastJson序列化驼峰-下划线转换问题踩坑记录
  • 基于Spring Boot的文字识别系统
  • 逆波兰表达式求值
  • 【面试经验】华为产品行销经理面经
  • 数据赋能(187)——开发:数据产品——概述、关注焦点
  • 超详细Git的基本命令使用(三)
  • C++入门基础知识43——【关于C++循环】
  • Golang | Leetcode Golang题解之第384题打乱数组