当前位置: 首页 > article >正文

Transformer 代码剖析6 - 位置编码 (pytorch实现)

一、位置编码的数学原理与设计思想

1.1 核心公式解析

位置编码采用正弦余弦交替编码方案:
P E ( p o s , 2 i ) = sin ⁡ ( p o s 1000 0 2 i / d m o d e l ) P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s 1000 0 2 i / d m o d e l ) PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) \\ PE_{(pos,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i)=sin(100002i/dmodelpos)PE(pos,2i+1)=cos(100002i/dmodelpos)

式中:

  • p o s pos pos:当前词在序列中的绝对位置
  • i i i:特征维度的索引( 0 ≤ i < d m o d e l / 2 0 \leq i < d_{model}/2 0i<dmodel/2
  • 1000 0 2 i / d m o d e l 10000^{2i/d_{model}} 100002i/dmodel:频率控制项,形成指数衰减的频率分布

1.2 设计优势分析

1. 绝对位置感知: 每个位置生成唯一编码模式
2. 相对位置建模: 通过三角函数加法公式可推导任意两个位置的关联度
3. 多频特征捕捉: 不同频率的正余弦波组合形成丰富的表征空间
4. 值域归一化: 所有编码值分布在[-1,1]区间,与词嵌入维度保持数值一致性
(图示:不同维度上的位置编码波形,高频维度对应快速变化,低频维度对应缓慢变化)
(图示:不同维度上的位置编码波形,高频维度对应快速变化,低频维度对应缓慢变化)

二、代码架构与执行流程

2.1 类结构设计

PositionalEncoding
__init__构造函数
创建零矩阵
配置梯度策略
构建位置索引
生成维度索引
计算正弦编码
计算余弦编码
forward前向传播
获取输入尺寸
返回截断编码

2.2 核心代码模块

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len, device):
        super().__init__()
        # 编码矩阵初始化(关键参数说明)
        self.encoding = torch.zeros(max_len, d_model, device=device)
        self.encoding.requires_grad = False  # 冻结梯度计算
        
        # 位置索引构建(维度变换演示)
        pos = torch.arange(0, max_len, device=device).float().unsqueeze(dim=1)
        
        # 维度索引生成(步长控制逻辑)
        _2i = torch.arange(0, d_model, step=2, device=device).float()
        
        # 编码计算过程(数学实现)
        self.encoding[:, 0::2] = torch.sin(pos / (10000  (_2i / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000  (_2i / d_model)))

    def forward(self, x):
        batch_size, seq_len = x.size()
        return self.encoding[:seq_len, :]

三、逐行代码深度解析

3.1 构造函数解析

def __init__(self, d_model, max_len, device):
    super(PositionalEncoding, self).__init__()
  • 功能说明:继承PyTorch模块基类,初始化可训练参数
  • 参数详解:
    • d_model:编码维度(需与词嵌入维度一致)
    • max_len:预计算的最大序列长度(如512对应BERT标准配置)
    • device:硬件加速配置(实现跨平台兼容)
    self.encoding = torch.zeros(max_len, d_model, device=device)
    self.encoding.requires_grad = False
  • 设计意图:创建静态编码矩阵,避免反向传播计算
  • 内存优化:通过requires_grad=False节省显存占用
  • 维度说明:矩阵形状为[max_len, d_model],例如max_len=512时生成512x512矩阵
    pos = torch.arange(0, max_len, device=device)
    pos = pos.float().unsqueeze(dim=1)
  • 位置索引构建:生成[0,1,…,max_len-1]的连续位置序列
  • 维度变换:通过unsqueeze将1D张量转换为2D(max_len,1),便于广播计算
    _2i = torch.arange(0, d_model, step=2, device=device).float()
  • 步长控制:step=2确保交替访问奇偶索引
  • 数值范围:当d_model=512时,生成[0,2,4,…,510]的索引序列
    self.encoding[:, 0::2] = torch.sin(pos / (10000  (_2i / d_model)))
    self.encoding[:, 1::2] = torch.cos(pos / (10000  (_2i / d_model)))
  • 分片赋值:通过0::21::2实现奇偶列交替填充
  • 频率控制:10000 (_2i/d_model)生成指数衰减的频率系数

3.2 前向传播解析

def forward(self, x):
    batch_size, seq_len = x.size()
    return self.encoding[:seq_len, :]
  • 动态适配:根据实际输入序列长度截取编码
  • 广播机制:自动扩展编码矩阵到批次维度(无需显式复制)
  • 数值叠加:后续与词嵌入进行element-wise相加操作

四、张量运算可视化演示

4.1 示例参数配置

假设:

  • d_model = 4
  • max_len = 3
  • device = 'cpu'

4.2 计算过程推演

步骤1:生成位置索引

pos = [[0],
       [1],
       [2]]  # shape (3,1)

步骤2:创建维度索引

_2i = [0, 2]  # d_model=4时step=2生成

步骤3:计算频率项

频率项 = 10000^( (0/4), (2/4) ) 
       = [1, 10000^0.5] 
       ≈ [1, 100]

步骤4:计算位置编码

sin项:
pos / [1, 100] = [[0/1, 0/100],
                 [1/1, 1/100],
                 [2/1, 2/100]]
               = [[0, 0],
                  [1, 0.01],
                  [2, 0.02]]
sin值:
[[0, 0],
 [0.8415, 0.00999983],
 [0.9093, 0.01999867]]

cos项计算同理...

最终编码矩阵:

PE = [
  [sin(0), cos(0), sin(0), cos(0)],      # 位置0
  [sin(1), cos(0.01), sin(1), cos(0.01)],# 位置1
  [sin(2), cos(0.02), sin(2), cos(0.02)] # 位置2
]

五、工程实践与优化策略

5.1 配置参数建议

  1. max_len设定:应大于训练数据最大序列长度20%
  2. 设备兼容性:通过device参数统一管理计算设备
  3. 混合精度训练:可将编码矩阵转为half精度

5.2 性能优化技巧

  1. 预计算缓存:提前生成编码矩阵避免运行时计算
  2. 内存映射:对超长序列使用内存映射文件
  3. 稀疏矩阵:对长文本场景采用分块加载策略

六、与其他模块的协同工作

6.1 与词嵌入的集成

class TransformerEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model, max_len, device, dropout):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model, max_len, device)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb(x)
        return self.dropout(tok_emb + pos_emb)
  • 加法融合:通过element-wise相加实现信息融合
  • 梯度隔离:位置编码不参与梯度更新
  • 维度验证:确保tok_embpos_emb维度严格一致

七、典型应用场景分析

7.1 文本生成任务

  • 长序列处理:通过位置编码捕获远距离依赖
  • 解码器优化:在自回归生成时动态调整位置编码

7.2 语音识别系统

  • 时序建模:精确捕捉语音信号的时序特征
  • 多尺度编码:结合不同频率分量处理语音信号

八、扩展研究方向

  1. 相对位置编码:改进绝对位置编码的局限性
  2. 动态频率调整:根据输入数据自动调节频率参数
  3. 混合编码方案:结合可学习参数与固定编码
  4. 量子化压缩:对编码矩阵进行低比特量化

原项目代码(附)

"""
@author : Hyunwoong
@when : 2019-10-22
@homepage : https://github.com/gusdnd852
"""

import torch
from torch import nn

# 定义一个名为PositionalEncoding的类,它继承自nn.Module,用于计算正弦位置编码。
class PositionalEncoding(nn.Module):
    """
    计算正弦位置编码的类。
    """

    def __init__(self, d_model, max_len, device):
        """
        PositionalEncoding类的构造函数。

        :param d_model: 模型的维度(即嵌入向量的大小)。
        :param max_len: 序列的最大长度。
        :param device: 硬件设备设置(CPU或GPU)。
        """

        super(PositionalEncoding, self).__init__()  # 调用父类nn.Module的构造函数。

        # 初始化一个与输入矩阵大小相同的零矩阵,用于存储位置编码,以便后续与输入矩阵相加。
        self.encoding = torch.zeros(max_len, d_model, device=device)
        self.encoding.requires_grad = False  # 我们不需要计算位置编码的梯度。

        # 创建一个从0到max_len-1的一维张量,表示序列中的位置索引。
        pos = torch.arange(0, max_len, device=device)
        # 将位置索引张量转换为浮点数,并增加一个维度,从1D变为2D,以表示每个位置的索引。
        pos = pos.float().unsqueeze(dim=1)
        # 1D => 2D,增加维度以表示单词的位置。

        # 创建一个从0到d_model-1,步长为2的一维浮点数张量,用于计算正弦和余弦函数的指数部分。
        _2i = torch.arange(0, d_model, step=2, device=device).float()
        # 'i'表示d_model的索引(例如,嵌入大小=50时,'i'的范围为[0,50])。
        # "step=2"意味着'i'每次增加2(相当于2*i)。

        # 使用正弦函数计算位置编码的偶数索引位置的值。
        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
        # 使用余弦函数计算位置编码的奇数索引位置的值。
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
        # 计算位置编码,以考虑单词的位置信息。

    def forward(self, x):
        # self.encoding是预先计算好的位置编码矩阵。
        # [max_len = 512, d_model = 512],表示最大长度为512,维度为512的位置编码。

        # 获取输入x的批次大小和序列长度。
        batch_size, seq_len = x.size()
        # [batch_size = 128, seq_len = 30],表示批次大小为128,序列长度为30。

        # 返回与输入序列长度相匹配的位置编码。
        return self.encoding[:seq_len, :]
        # [seq_len = 30, d_model = 512],返回的形状为序列长度乘以维度。
        # 它将与输入嵌入(tok_emb)相加,tok_emb的形状通常为[128, 30, 512]。


http://www.kler.cn/a/570669.html

相关文章:

  • 机器学习11-经典网络解析
  • AI语音交互模组方案,设备无线物联网控制,实时语音联动应用
  • 数据结构:二叉搜索树(排序树)
  • Redis高可用部署:3台服务器打造哨兵集群
  • 基于 Rust 与 GBT32960 规范的编解码层
  • 动态表头报表的绘制与导出
  • 基于 Elasticsearch 和 Milvus 的 RAG 运维知识库的架构设计和部署落地实现指南
  • 深入剖析Java NIO的epoll机制:红黑树、触发模式与CPU缓存优化
  • 运动想象 (MI) 分类学习系列 (17) : CCSM-FT
  • OCR PDF 文件是什么?它包含什么内容?
  • 力扣 最长回文子串
  • M4 Mac mini运行DeepSeek-R1模型
  • 03.03 QT
  • 如何本地部署大模型及性能优化指南(附避坑要点)
  • AI预测福彩3D新模型百十个定位预测+胆码预测+杀和尾+杀和值2025年3月3日第11弹
  • WordPress ltl-freight-quotes-estes-edition sql注入漏洞(CVE-2024-13479)
  • Linux虚拟机网络配置-桥接网络配置
  • 【向量数据库Weaviate】与ChromaDB的差异、优劣
  • 刚安装docker并启动docker服务: systemctl restart docker报错解决
  • [RN]React Native知识框架图详解