当前位置: 首页 > article >正文

多头自注意力中的多头作用及相关思考

文章目录

  • 1. num_heads
  • 2. pytorch源码演算

1. num_heads

将矩阵的最后一维度进行按照num_heads的方式进行切割矩阵,具体表示如下:
在这里插入图片描述
在这里插入图片描述

2. pytorch源码演算

  • pytorch 代码
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    run_code = 0
    batch_size = 2
    seq_len = 4
    model_dim = 6
    num_heads = 3
    mat_total = batch_size * seq_len * model_dim
    mat1 = torch.arange(mat_total).reshape((batch_size, seq_len, model_dim))
    print(f"mat1=\n{mat1}")
    head_dim = model_dim // num_heads
    mat2 = mat1.reshape((batch_size, seq_len, num_heads, head_dim))
    print(f"mat2=\n{mat2}")
    mat3 = mat2.transpose(1, 2)
    print(f"mat3=\n{mat3}")
    mat4 = mat3.reshape((batch_size*num_heads,seq_len,head_dim))
    print(f"mat1.shape=\n{mat1.shape}")
    print(f"mat1=\n{mat1}")

    print(f"mat4.shape=\n{mat4.shape}")
    print(f"mat4=\n{mat4}")
  • 结果:
mat1=
tensor([[[ 0,  1,  2,  3,  4,  5],
         [ 6,  7,  8,  9, 10, 11],
         [12, 13, 14, 15, 16, 17],
         [18, 19, 20, 21, 22, 23]],

        [[24, 25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34, 35],
         [36, 37, 38, 39, 40, 41],
         [42, 43, 44, 45, 46, 47]]])
mat2=
tensor([[[[ 0,  1],
          [ 2,  3],
          [ 4,  5]],

         [[ 6,  7],
          [ 8,  9],
          [10, 11]],

         [[12, 13],
          [14, 15],
          [16, 17]],

         [[18, 19],
          [20, 21],
          [22, 23]]],


        [[[24, 25],
          [26, 27],
          [28, 29]],

         [[30, 31],
          [32, 33],
          [34, 35]],

         [[36, 37],
          [38, 39],
          [40, 41]],

         [[42, 43],
          [44, 45],
          [46, 47]]]])
mat3=
tensor([[[[ 0,  1],
          [ 6,  7],
          [12, 13],
          [18, 19]],

         [[ 2,  3],
          [ 8,  9],
          [14, 15],
          [20, 21]],

         [[ 4,  5],
          [10, 11],
          [16, 17],
          [22, 23]]],


        [[[24, 25],
          [30, 31],
          [36, 37],
          [42, 43]],

         [[26, 27],
          [32, 33],
          [38, 39],
          [44, 45]],

         [[28, 29],
          [34, 35],
          [40, 41],
          [46, 47]]]])
mat1.shape=
torch.Size([2, 4, 6])
mat1=
tensor([[[ 0,  1,  2,  3,  4,  5],
         [ 6,  7,  8,  9, 10, 11],
         [12, 13, 14, 15, 16, 17],
         [18, 19, 20, 21, 22, 23]],

        [[24, 25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34, 35],
         [36, 37, 38, 39, 40, 41],
         [42, 43, 44, 45, 46, 47]]])
mat4.shape=
torch.Size([6, 4, 2])
mat4=
tensor([[[ 0,  1],
         [ 6,  7],
         [12, 13],
         [18, 19]],

        [[ 2,  3],
         [ 8,  9],
         [14, 15],
         [20, 21]],

        [[ 4,  5],
         [10, 11],
         [16, 17],
         [22, 23]],

        [[24, 25],
         [30, 31],
         [36, 37],
         [42, 43]],

        [[26, 27],
         [32, 33],
         [38, 39],
         [44, 45]],

        [[28, 29],
         [34, 35],
         [40, 41],
         [46, 47]]])
  • 思考: 在矩阵y=Ax表示的时候,如果我们无法用Ax整体表示y的时候,我们可以通过将矩阵A的列向量进行拆分后得到A1,A2,A3,这样y=(A1,A2,A3)x表示更合理。

http://www.kler.cn/a/540320.html

相关文章:

  • 指定路径安装Ollama
  • 【MQ】Spring3 中 RabbitMQ 的使用与常见场景
  • Chapter2:C#基本数据类型
  • 使用 Postman 进行 API 测试:从入门到精通
  • SpringBoot速成(七)注册实战P2-P4
  • Deepseek的MLA技术原理介绍
  • 《我在技术交流群算命》(三):QML的Button为什么有个蓝框去不掉啊(QtQuick.Controls由Qt5升级到Qt6的异常)
  • 深入理解QT的View-Model-Delegate机制和用法
  • 开发指南098-logback-spring.xml说明
  • C# 学习目录
  • 海外直播场景下的AWS技术架构设计与实践
  • 【医院管理会计专题】2.管理会计:医院运营管理的隐形引擎
  • AutoMQ 如何实现没有写性能劣化的极致冷读效率
  • 11g ADG主备切换步骤
  • 【JAVA使用Aes加密报错:Illegal key size or default parameters,如何解决?】
  • FFmpeg 学习路径
  • VeryReport和FineReport两款报表软件深度分析对比
  • 只需三步!5分钟本地部署deep seek——MAC环境
  • MongoDB 的使用场景
  • Transformers as SVM(2023 NIPS)
  • react概览webpack基础
  • zynq tcp万兆网和ftp协议分析
  • 如何查看用户的详细身份信息
  • 向量数据库简单对比
  • fps动作系统9:动画音频
  • flutter 默认跳转封装