当前位置: 首页 > article >正文

Transformer中的嵌入位置编码

在Transformer中,使用余弦编码或其他类似的编码方式(如正弦-余弦位置编码)而不是简单的“0123456”这种数字编码,主要是因为位置编码的目标是为模型提供位置信息,同时又不引入过多的显式顺序假设。

主要原因如下:

  1. 避免数字编码的离散性: 如果使用简单的数字编码(如0, 1, 2, 3, …),这种编码方法会暗示数字之间有某种固定的数学关系,而实际上,位置之间的关系是相对的而非线性的。例如,在“0123456”这种编码下,位置1和位置2之间的差异与位置6和位置7之间的差异是相同的,但它们在语义上可能并不等价。余弦编码则避免了这种线性关系,它能以周期性的方式映射每个位置,使得模型能够更灵活地学习不同位置之间的关系。

  2. 周期性特性: 余弦编码是基于正弦和余弦函数的,具有自然的周期性。这对于处理循环性质或长距离依赖(如句子的开始和结束、或长期的语法结构)尤其有效。余弦编码的这种周期性特性使得模型可以捕捉到位置之间的相对关系,不管它们之间的距离有多远。

  3. 无固定顺序假设: Transformer的自注意力机制(Self-Attention)是无序的,也就是说,模型本身并不假设序列的顺序是固定的,位置编码的引入是为了补充这个信息。如果使用“0123456”这种数字编码,模型可能会学习到数字之间的大小顺序,而这是不必要的,尤其在处理长文本时,顺序本身的编码可能会导致模型偏向某些固定模式。使用余弦编码则帮助模型从不同的角度(不同的频率)感知位置关系,而不是仅仅依赖于线性编码。

  4. 高效的学习能力: 余弦编码和正弦编码的连续性使得模型在学习时可以更容易地捕捉位置之间的相对关系,特别是对于不同长度的序列。它们的频率变化也帮助模型将不同位置的编码“拉开”距离,使得模型能够更清晰地区分不同的位置信息。

import torch
import math

class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        
        # 创建一个位置编码矩阵,大小为 max_len x d_model
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()  # shape: (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))  # shape: (d_model / 2,)
        
        # 计算每个位置的编码
        pe[:, 0::2] = torch.sin(position * div_term)  # 偶数位置
        pe[:, 1::2] = torch.cos(position * div_term)  # 奇数位置
        
        pe = pe.unsqueeze(0)  # 增加batch维度,shape: (1, max_len, d_model)
        
        self.register_buffer('pe', pe)  # 注册为buffer,不会更新
        
    def forward(self, x):
        # x: 输入的张量,shape: (batch_size, seq_len, d_model)
        return x + self.pe[:, :x.size(1)].detach()

# 示例
d_model = 512  # 嵌入维度
max_len = 60   # 序列最大长度
position_encoding = PositionalEncoding(d_model, max_len)

# 假设输入一个批次的序列,shape为 (batch_size, seq_len, d_model)
batch_size = 32
seq_len = 50
x = torch.randn(batch_size, seq_len, d_model)

# 加入位置编码
x_pos = position_encoding(x)
print(x_pos.shape)  # 应该是 (batch_size, seq_len, d_model)

余弦编码和其他类似的连续位置编码方式提供了一个能够捕捉更复杂的位置信息的机制,而简单的数字编码往往过于局限,难以适应Transformer模型的自注意力机制及其灵活的处理能力。


http://www.kler.cn/a/539187.html

相关文章:

  • 点大商城V2-2.6.6源码全开源uniapp +搭建教程
  • 机器学习笔记
  • 了解“/linux-5.4.31/drivers/of/device.c”中的of_device_get_match_data()
  • RabbitMq入门
  • Docker Desktop安装到其他盘
  • 计算机视觉的研究方向、发展历程、发展前景介绍
  • Golang:Go 1.23 版本新特性介绍
  • 小程序实现消息订阅通知完整实践及踩坑记录
  • AI绘画:开启艺术与科技融合的未来之门(10/10)
  • Unity3D仿星露谷物语开发28之切换场景
  • 【神经网络框架】非局部神经网络
  • [LeetCode]day18 202.快乐数
  • Redis的数据过期策略和数据淘汰策略
  • 【计算机视觉】多分辨率金字塔全解析 ✨
  • 机试题——D路通信
  • Sparse4D v3:推进端到端3D检测和跟踪
  • Android系统SELinux详解
  • 携手AWS,零成本在EKS上体验AutoMQ企业版
  • 计算机网络知识速记:TCP 与 UDP
  • 六、OSG学习笔记-漫游(操作器)
  • ViewModel和LiveData
  • ES6中的模板字符串
  • 2025年2月9日(数据分析,在最高点和最低点添加注释,添加水印)
  • 面向对象设计在Java程序开发中的最佳实践研究
  • 【服务器知识】如何在linux系统上搭建一个nfs
  • springboot 事务管理