
[Deep Learning] einsum, the All-Purpose Matrix-Operation Function: Einstein Summation

A very nice repo for learning the Transformer: https://github.com/tianxinliao/Transformer-learning. Noting it here for my own use.
ref: https://blog.csdn.net/zhaohongfei_358/article/details/125273126
While studying the Transformer, I came across the following code:

        values = self.values(values)  # (N, value_len, embed_size)
        keys = self.keys(keys)  # (N, key_len, embed_size)
        queries = self.queries(query)  # (N, query_len, embed_size)

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)

        # Einsum computes the query*key dot products for every query/key
        # position pair, per example and per head; it is just a compact
        # way of writing batched matrix multiplication (bmm).

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, head_dim)
        # keys shape: (N, key_len, heads, head_dim)
        # energy shape: (N, heads, query_len, key_len)

That line completely lost me, so this time I am going to study einsum properly and see what is going on (we will come back to that exact line at the end). It has a bit of an "easier felt than explained" quality: a classic case of being keen but not very good. My understanding just wasn't deep enough!

einsum is available in both NumPy and PyTorch. It works through index labels: each axis of each operand gets a letter, a letter that appears in several operands is matched up (multiplied elementwise along that axis), and any letter missing from the output is summed over. The examples below walk through the common cases.

import torch
import torch.nn as nn
import torch.optim as optim
x = torch.rand((2, 3))
v = torch.rand((1, 3))
print(torch.einsum('ij,kj->ik', x, v).shape)    # matrix multiplication: x @ v.T
print(torch.einsum('ij,kj->ki', x, v).shape)    # the same, with the output transposed
print(torch.einsum('ij,km->ijkm', x, v).shape)  # no summed index: an outer product, not a concatenation
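These print torch.Size([2, 1]), torch.Size([1, 2]) and torch.Size([2, 3, 1, 3]). As a quick sanity check, here is a minimal sketch (same shapes as above) confirming that 'ij,kj->ik' really is x @ v.T:

import torch

x = torch.rand((2, 3))
v = torch.rand((1, 3))

# j is repeated across the operands and absent from the output, so it is
# multiplied elementwise and summed over: this is x @ v.T, shape (2, 1).
assert torch.allclose(torch.einsum('ij,kj->ik', x, v), x @ v.T)

# 'ki' just swaps the output axes: (x @ v.T).T, shape (1, 2).
assert torch.allclose(torch.einsum('ij,kj->ki', x, v), (x @ v.T).T)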
import torch
x = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])
y = torch.tensor([
    [7, 8, 9]
])
x,y
(tensor([[1, 2, 3],
         [4, 5, 6]]),
 tensor([[7, 8, 9]]))
result = torch.einsum('ij,km->ijkm', x, y)
result
tensor([[[[ 7,  8,  9]],

         [[14, 16, 18]],

         [[21, 24, 27]]],


        [[[28, 32, 36]],

         [[35, 40, 45]],

         [[42, 48, 54]]]])
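No index is repeated in 'ij,km->ijkm', so nothing gets summed: every element of x is multiplied with every element of y, i.e. result[i, j, k, m] = x[i, j] * y[k, m]. This is an outer product rather than a concatenation. A minimal check against plain broadcasting:

import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
y = torch.tensor([[7, 8, 9]])

result = torch.einsum('ij,km->ijkm', x, y)
# Broadcasting builds the same (2, 3, 1, 3) outer product.
assert torch.equal(result, x[:, :, None, None] * y[None, None, :, :])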
a = [
    [[1, 2],   # i = 0
     [3, 4]],
    [[5, 6],   # i = 1
     [7, 8]]
]

b = [
    [[9, 10, 11],   # i = 0
     [12, 13, 14]],
    [[15, 16, 17],  # i = 1
     [18, 19, 20]]
]

torch.tensor(a[0]).shape, torch.tensor(b[0]).shape
(torch.Size([2, 2]), torch.Size([2, 3]))

torch.tensor(a[0]) @ torch.tensor(b[0])
tensor([[33, 36, 39],
        [75, 82, 89]])
torch.tensor(a[1]) @ torch.tensor(b[1])
tensor([[183, 194, 205],
        [249, 264, 279]])
res = []
for i in range(len(a)):
    a1 = torch.tensor(a[i])
    b1 = torch.tensor(b[i])
    res.append(a1 @ b1)   # one matmul per batch element
res1 = torch.stack(res)   # stack the results into one (2, 2, 3) tensor
print(res, "\n", res1)
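This loop-and-stack pattern is exactly batched matrix multiplication, and einsum expresses it in one line: keep the batch index b in the output and contract over the shared index j. A minimal sketch reusing the a and b defined above:

import torch

A = torch.tensor(a)  # shape (2, 2, 2)
B = torch.tensor(b)  # shape (2, 2, 3)

# 'bij,bjk->bik': b stays (one matmul per batch element), j is summed over.
out = torch.einsum('bij,bjk->bik', A, B)
assert torch.equal(out, torch.stack([A[i] @ B[i] for i in range(len(A))]))
assert torch.equal(out, torch.bmm(A, B))  # identical to torch.bmm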
x = torch.rand(3, 3)
torch.einsum('ii->i', x), x
(tensor([0.7127, 0.3843, 0.2046]),
 tensor([[0.7127, 0.0171, 0.9940],
         [0.6781, 0.3843, 0.9031],
         [0.4963, 0.1581, 0.2046]]))
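'ii->i' picks out the diagonal because both axes of x carry the same index, so only elements with equal row and column indices survive. Dropping i from the output sums those elements instead, which gives the trace. A small sketch:

import torch

x = torch.rand(3, 3)
assert torch.allclose(torch.einsum('ii->i', x), torch.diagonal(x))
assert torch.allclose(torch.einsum('ii->', x), torch.trace(x))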
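Back to the line that started all this: in 'nqhd,nkhd->nhqk', the index d is repeated and missing from the output, so for every batch n and head h, einsum takes the dot product over head_dim between each query position q and each key position k. It is just a batched matmul with the head axis moved forward. A minimal sketch with hypothetical sizes:

import torch

N, query_len, key_len, heads, head_dim = 2, 4, 5, 8, 16  # hypothetical sizes
queries = torch.rand(N, query_len, heads, head_dim)
keys = torch.rand(N, key_len, heads, head_dim)

energy = torch.einsum('nqhd,nkhd->nhqk', queries, keys)

# The same computation spelled out: move heads forward, then matmul
# over the last two axes.
manual = torch.matmul(
    queries.permute(0, 2, 1, 3),  # (N, heads, query_len, head_dim)
    keys.permute(0, 2, 3, 1),     # (N, heads, head_dim, key_len)
)
assert torch.allclose(energy, manual)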

