当前位置: 首页 > article >正文

Transformer 代码剖析7 - 词元嵌入(TokenEmbedding) (pytorch实现)

一、类定义与继承关系剖析

1.1 代码结构图示

神经网络基础模块
词嵌入基类
自定义词元嵌入
构造函数定义
基类初始化
词汇量参数
维度参数
填充标识参数

1.2 代码实现精讲

"""
@author : Hyunwoong
@when : 2019-10-22
@homepage : https://github.com/gusdnd852
"""
from torch import nn

class TokenEmbedding(nn.Embedding):
    """
    基于PyTorch实现的动态词元嵌入模块
    实现词元索引到高维向量的可学习映射
    核心功能:将离散的词元序列转换为连续的语义空间表示
    """
    
    def __init__(self, vocab_size, d_model):
        """
        词元嵌入构造器
        
        :param vocab_size: 词表容量(不同词元的总数)
        :param d_model: 嵌入维度(与Transformer模型维度一致)
        设计要点:
        - 继承nn.Embedding的矩阵运算特性
        - 固化填充索引为可训练参数
        - 保持维度与模型其他组件兼容
        """
        super(TokenEmbedding, self).__init__(
            vocab_size, # 嵌入数量 num_embeddings # 嵌入矩阵行数 = 词表大小
            d_model, # 嵌入维度 embedding_dim # 嵌入矩阵列数 = 模型维度
            padding_idx=1 # 填充符索引的特殊处理
        )

二、核心参数深度解读

2.1 参数矩阵可视化

假设词表容量vocab_size=10000,模型维度d_model=512时:

参数维度元素数量数学意义
weight[10000,512]5,120,000可训练的嵌入查询矩阵
padding_idxscalar1动态掩码位置标识

2.2 关键参数说明

1. vocab_size

  • 控制嵌入矩阵的行维度
  • 决定模型可处理的词元种类上限
  • 典型值域:BERT系列(~30000),GPT系列(~50000)

2. d_model

  • 控制嵌入向量的列维度
  • 与Transformer隐藏层维度严格对齐
  • 典型值域:512(原始论文)、768(BERT-base)、1024(大型模型)

3. padding_idx

  • 实现动态序列掩码的关键参数
  • 索引位置对应的梯度会被自动抑制
  • 防止填充符影响模型语义理解

三、运算过程分步推演

3.1 前向传播示例

输入序列:[3, 28, 1, 0] (1为填充符)

运算步骤:

1. 建立索引映射:

[[3],[[0.2, -0.5, ..., 1.2],  # 索引3的嵌入
 [28],[0.7, 1.1, ..., -0.3],   # 索引28的嵌入
 [1],[0.0, 0.0, ..., 0.0],    # 填充符固定值
 [0]][-0.9, 0.4, ..., 0.1]]   # 索引0的嵌入

2. 矩阵缩放(后续处理):

embeddings * sqrt(d_model)  # 维度对齐的数学技巧

3.2 梯度传播特性

  • 可微分性: 整个映射过程保持梯度通路
  • 参数更新: 通过反向传播调整嵌入矩阵
  • 特殊处理: padding_idx位置梯度始终为0

四、设计哲学解析

4.1 继承关系价值

TokenEmbedding
torch.nn.Embedding
torch.nn.Module
PyTorch基础设施

优势分析:

  • 复用性:继承矩阵运算和参数管理功能
  • 扩展性:保留自定义前向传播的可能性
  • 兼容性:无缝对接PyTorch生态工具

4.2 工程实践建议

1. 初始化技巧:

  • 默认采用均匀分布 U ( − 1 d m o d e l , 1 d m o d e l ) U(-\sqrt{\frac{1}{d_{model}}}, \sqrt{\frac{1}{d_{model}}}) U(dmodel1 ,dmodel1 )
  • 可扩展为Xavier/Kaiming初始化:
    # Xavier均匀初始化(默认)
    nn.init.xavier_uniform_(self.weight)
    
    # 特殊处理填充符
    self.weight.data[1].zero_()
    

2. 维度对齐策略:

# 与位置编码相加前的缩放
embeddings = embeddings * math.sqrt(d_model)

3. 混合精度训练:

# 自动转换为半精度
with autocast():
    embeddings = embedding_layer(input_ids)

4. 填充符处理机制:

  • 训练阶段自动跳过无效位置的计算
  • 推理阶段维持序列形状一致性

5. 计算复杂度分析:

  • 时间复杂度: O ( B ⋅ S ⋅ D ) O(B \cdot S \cdot D) O(BSD)
  • 空间复杂度: O ( V ⋅ D ) O(V \cdot D) O(VD)

完整实现细节可参考PyTorch中sparse.py 模块解析的相关文章(嵌入(Embedding)基类代码解析)或PyTorch官方Embedding文档。


http://www.kler.cn/a/569866.html

相关文章:

  • olmOCR:使用VLM解析PDF
  • Tattu发布全新行业无人机电池NEO系列,专为长续航设计
  • 【爬虫基础】第二部分 爬虫基础理论 P3/3
  • 一文掌握ADSL拨号代理的搭建方法,及详细使用
  • 数据结构--队列(C语言实现)
  • 一个非常好用便捷的web自动化爬虫工具Playwright
  • 大数据分析中的机器学习基础:从原理到实践
  • Dwall 动态壁纸自动匹配
  • 蓝桥杯深秋的苹果
  • 数据图表ScottPlot.WPF用法示例
  • HTTP 协议的发展历程:从 HTTP/1.0 到 HTTP/2.0
  • 【Linux】TCP协议
  • VScode C语言学习开发环境;运行提示“#Include错误,无法打开源文件stdio.h”
  • 计算机毕设-基于springboot的社团管理系统的设计与实现(附源码+lw+ppt+开题报告)
  • 小红的回文子串
  • 企业微信获取用户信息
  • MySQL增删改查(进阶)
  • 时序论文41 | Medformer:基于多粒度patch的时序分类模型
  • [含文档+PPT+源码等]精品基于Python实现的微信小程序的在线医疗咨询系统
  • 汽车智能钥匙低频PKE天线