当前位置: 首页 > article >正文

【AI知识】pytorch手写Attention之Self-Attention,Multi-Head-Attention

pytorch手写Attention

  • Self-Attention
  • Multi-Head-Attention

Self-Attention

代码:

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self,embed_dim):
        super(SelfAttention,self).__init__()
        self.embed_dim=embed_dim

        self.WQ=nn.Linear(embed_dim,embed_dim)
        self.WK=nn.Linear(embed_dim,embed_dim)
        self.WV=nn.Linear(embed_dim,embed_dim)
        self.dropout=nn.Dropout(0.1)

    def forward(self,x,mask=None):
        """
        输入序列x(batch_size,seq_len,embed_dim)
        """
        # (batch_size,seq_len,embed_dim)
        Q=self.WQ(x) 
        # (batch_size,seq_len,embed_dim)
        K=self.WK(x)
        # (batch_size,seq_len,embed_dim)
        V=self.WV(x)
        # K(batch_size,seq_len,embed_dim) ,K.transpose(-2,-1)交换张量的最后一个维度和倒数第二个维度
        attention_scores=torch.matmul(Q,K.transpose(-2,-1))/(self.embed_dim**0.5)
        # 被掩码的位置设为 -inf
        if mask is not None:
            attention_scores=attention_scores.masked_fill(mask==0,float('-inf'))
        # 沿着哪个维度进行 Softmax 计算(dim=-1 表示最后一个维度),对 seq_len 维度计算,让每个 Query 的注意力总和为 1
        attention_weights=torch.softmax(attention_scores,-1)
        output=torch.matmul(attention_weights,V)
        return output,attention_weights
    
def create_causal_mask(seq_len):
    """
    生成一个 (seq_len, seq_len) 的上三角矩阵
    """
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)  # 生成上三角部分为 1
    return mask == 0  # 0 表示被掩码,1 表示可用

测试:

batch_size=2
seq_len=5
embed_dim=5
x=torch.rand(batch_size,seq_len,embed_dim)

mask_true=create_causal_mask(seq_len)
print("mask:")
print(mask_true)
print("-"*50)
mask_false=None

self_attention=SelfAttention(embed_dim)

output_with_mask,weights_with_mask=self_attention(x,mask_true)
print("output_with_mask:")
print(output_with_mask)
print("weights_with_mask:")
print(weights_with_mask)
print("-"*50)

output_without_mask,weights_without_mask=self_attention(x,mask_false)
print("output_without_mask:")
print(output_without_mask)
print("weights_without_mask:")
print(weights_without_mask)

结果:

mask:
tensor([[ True, False, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True]])
--------------------------------------------------
output_with_mask:
tensor([[[-0.0250, -0.0048,  0.1955,  0.1222,  0.3228],
         [ 0.0082, -0.1107,  0.2676,  0.1467,  0.4512],
         [ 0.0056, -0.1283,  0.3186,  0.1582,  0.3351],
         [ 0.0030, -0.0939,  0.2760,  0.1519,  0.3447],
         [ 0.0192, -0.1143,  0.3045,  0.1725,  0.3280]],

        [[-0.0036, -0.4164,  0.3813,  0.2492,  0.5639],
         [-0.1500, -0.3267,  0.2072,  0.0787,  0.4852],
         [-0.0660, -0.2758,  0.2731,  0.1216,  0.4643],
         [-0.0864, -0.2271,  0.2297,  0.0849,  0.4300],
         [-0.0653, -0.1743,  0.2279,  0.0985,  0.3965]]],
       grad_fn=<UnsafeViewBackward0>)
weights_with_mask:
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4885, 0.5115, 0.0000, 0.0000, 0.0000],
         [0.3256, 0.3318, 0.3427, 0.0000, 0.0000],
         [0.2447, 0.2653, 0.2568, 0.2332, 0.0000],
         [0.1945, 0.2031, 0.2080, 0.1886, 0.2058]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5361, 0.4639, 0.0000, 0.0000, 0.0000],
         [0.3654, 0.3014, 0.3332, 0.0000, 0.0000],
         [0.2744, 0.2371, 0.2544, 0.2340, 0.0000],
         [0.2194, 0.1978, 0.2025, 0.1885, 0.1918]]],
       grad_fn=<SoftmaxBackward0>)
--------------------------------------------------
output_without_mask:
tensor([[[ 0.0190, -0.1140,  0.3034,  0.1720,  0.3307],
         [ 0.0191, -0.1138,  0.3035,  0.1722,  0.3295],
         [ 0.0191, -0.1133,  0.3035,  0.1724,  0.3276],
         [ 0.0191, -0.1147,  0.3040,  0.1720,  0.3311],
         [ 0.0192, -0.1143,  0.3045,  0.1725,  0.3280]],

        [[-0.0623, -0.1730,  0.2303,  0.1011,  0.3960],
         [-0.0615, -0.1707,  0.2301,  0.1013,  0.3945],
         [-0.0601, -0.1736,  0.2327,  0.1035,  0.3966],
         [-0.0621, -0.1722,  0.2301,  0.1011,  0.3955],
         [-0.0653, -0.1743,  0.2279,  0.0985,  0.3965]]],
       grad_fn=<UnsafeViewBackward0>)
weights_without_mask:
tensor([[[0.1964, 0.2094, 0.2038, 0.1892, 0.2012],
         [0.1961, 0.2053, 0.2040, 0.1902, 0.2044],
         [0.1953, 0.1990, 0.2055, 0.1928, 0.2075],
         [-0.0653, -0.1743,  0.2279,  0.0985,  0.3965]]],
       grad_fn=<UnsafeViewBackward0>)
weights_without_mask:
tensor([[[0.1964, 0.2094, 0.2038, 0.1892, 0.2012],
         [0.1961, 0.2053, 0.2040, 0.1902, 0.2044],
         [0.1953, 0.1990, 0.2055, 0.1928, 0.2075],
tensor([[[0.1964, 0.2094, 0.2038, 0.1892, 0.2012],
         [0.1961, 0.2053, 0.2040, 0.1902, 0.2044],
         [0.1953, 0.1990, 0.2055, 0.1928, 0.2075],
         [0.1961, 0.2053, 0.2040, 0.1902, 0.2044],
         [0.1953, 0.1990, 0.2055, 0.1928, 0.2075],
         [0.1953, 0.1990, 0.2055, 0.1928, 0.2075],
         [0.1958, 0.2123, 0.2055, 0.1866, 0.1998],
         [0.1945, 0.2031, 0.2080, 0.1886, 0.2058]],

        [[0.2205, 0.1906, 0.2044, 0.1864, 0.1981],
         [0.2171, 0.1878, 0.2034, 0.1879, 0.2038],
         [0.2250, 0.1856, 0.2051, 0.1834, 0.2009],
         [0.2195, 0.1897, 0.2035, 0.1872, 0.2002],
         [0.2194, 0.1978, 0.2025, 0.1885, 0.1918]]],
       grad_fn=<SoftmaxBackward0>)

代码中的一些用法解释:

1)torch.nn.Linear ()

torch.nn.Linear() 是 PyTorch 最基础的全连接层(线性变换层),用于执行以下操作:
在这里插入图片描述

nn.Linear(in_features, out_features, bias=True),in_features 是输入特征维度,out_features是输出特征维度,bias表示是否使用偏置项,默认为 True

2)K.transpose(-2,-1)

在 PyTorch 中,torch.transpose() 用于交换张量的两个维度。参数 -2 和 -1 是指张量的倒数第二个维度和最后一个维度。

K.transpose(-2, -1) 和 K.transpose(-1, -2) 都是交换最后两个维度,它们的效果完全相同。

3)torch.triu(torch.ones(seq_len, seq_len), diagonal=1)

def create_causal_mask(seq_len):
    """
    生成一个 (seq_len, seq_len) 的上三角矩阵
    """
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)  # 生成上三角部分为 1
    return mask == 0  # 0 表示被掩码,1 表示可用

此函数解释:

  • torch.ones(seq_len, seq_len) 生成 seq_len × seq_len 矩阵,所有元素都是 1
  • torch.triu(…, diagonal=1) 取 上三角(不包括主对角线),上三角是 1,其余 0
  • mask == 0 把 0 变成 True(可用),把 1 变成 False(被屏蔽)

这样返回一个seq_len=5的mask矩阵:

tensor([[ True, False, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True]])

4)attention_scores.masked_fill(mask==0,float('-inf'))

作用: 根据 mask矩阵 进行掩码处理,将 mask == 0 的位置填充为 -inf(负无穷),使其在 softmax 计算时权重变为 0

masked_fill 属于 PyTorch 张量(torch.Tensor)的方法,用于根据布尔掩码(mask)填充指定值。masked_fill 的语法: tensor.masked_fill(mask, value)

tensor:要修改的张量
mask:布尔掩码(True/False 或 0/1)
value:要填充的值(如 -inf)

解释:attention_scores.masked_fill(mask==0,float('-inf'))

mask == 0 选取 应该被屏蔽的位置(即 上三角部分)
masked_fill(mask == 0, -inf) 把上三角部分设为-inf
这样,Softmax 后被屏蔽的部分变成 0,不会影响注意力计算。

5) torch.softmax(attention_scores,-1)
作用: 对 attention_scores 进行 Softmax 归一化,确保注意力权重(attn_weights)的总和为 1,控制每个 Token 对序列中其他 Token 的关注程度

torch.softmax(input, dim) 语法:

input:要进行 Softmax 计算的张量
dim:沿着哪个维度进行 Softmax 计算(dim=-1 表示最后一个维度)

Multi-Head-Attention

import torch
import torch.nn as nn
import math

class Multi_Head_Attention(nn.Module):
    def __init__(self,embed_dim,nums_heads):
        super(Multi_Head_Attention,self).__init__()
        
        assert embed_dim % nums_heads ==0,"embed_dim 必须能被 num_heads 整除"
        self.embed_dim=embed_dim
        self.nums_heads=nums_heads
        self.head_dim=embed_dim//nums_heads

        self.WQ=nn.Linear(embed_dim,embed_dim)
        self.WK=nn.Linear(embed_dim,embed_dim)
        self.WV=nn.Linear(embed_dim,embed_dim)

        self.fc=nn.Linear(embed_dim,embed_dim)
        self.scale=math.sqrt(embed_dim)

    def forward(self,x,mask=None):
        batch_size,seq_len,embed_dim=x.shape
        # Q,K,V: batch_size,seq_len,embed_dim
        Q=self.WQ(x)
        K=self.WK(x)
        V=self.WV(x)
        # Q,K,V: batch_size,seq_len,embed_dim -> batch_size,seq_len,self.nums_heads,self.head_dim -> batch_size,self.nums_heads,seq_len,self.head_dim
        Q=Q.view(batch_size,seq_len,self.nums_heads,self.head_dim).transpose(1,2)
        K=K.view(batch_size,seq_len,self.nums_heads,self.head_dim).transpose(1,2)
        V=V.view(batch_size,seq_len,self.nums_heads,self.head_dim).transpose(1,2)
        # batch_size,self.nums_heads,seq_len,seq_len
        attn_scores=torch.matmul(Q,K.transpose(-1,-2))/self.scale
        
        if mask is not None:
            attn_scores=attn_scores.masked_fill(mask==0,float('-inf'))
        attn_weights=torch.softmax(attn_scores,-1)
        # batch_size,self.nums_heads,seq_len,head_dim
        output=torch.matmul(attn_weights,V)
        # batch_size,self.nums_heads,seq_len,head_dim -> batch_size,seq_len,self.nums_heads,head_dim -> batch_size,seq_len,self.embed_dim
        output=output.transpose(1,2).contiguous().view(batch_size,seq_len,self.embed_dim)
        output=self.fc(output)
        return output,attn_weights
    
def create_mask(seq_len):
    mask=torch.triu(torch.ones(seq_len,seq_len),diagonal=1)
    return mask==0

测试:

batch_size=2
seq_len=5
nums_heads=3
embed_dim=6
x=torch.randn(batch_size,seq_len,embed_dim)
mask=create_mask(seq_len)

multiheadattention=Multi_Head_Attention(embed_dim,nums_heads)
output,weights=multiheadattention(x,mask)
print(output)
print(weights)

http://www.kler.cn/a/595484.html

相关文章:

  • vue3源码分析 -- computed
  • 深度解析学术论文成果评估(Artifact Evaluation):从历史到现状
  • 【问题解决】Postman 测试报错 406
  • 深入理解Java虚拟机(学习笔记)
  • java基础--序列化与反序列化的概念是什么?
  • 关于FastAPI框架的面试题及答案解析
  • 查看visual studio的MSVC版本的方法
  • 23 种设计模式中的访问者模式
  • 零基础上手Python数据分析 (7):Python 面向对象编程初步
  • 蓝桥杯 之 暴力回溯
  • 3.16[A]FPGA
  • Pytest基础使用
  • Netty源码—3.Reactor线程模型三
  • L2TP实验报告
  • 无服务器架构将淘汰运维?2025年云计算形态预测
  • RabbitMQ 与 Kafka:消息中间件的终极对比与选型指南
  • MSE分类时梯度消失的问题详解和交叉熵损失的梯度推导
  • Redis哨兵模式(Sentinel)高可用方案介绍与配置实践
  • 数字孪生技术引领UI前端设计新风尚:跨平台与响应式设计的结合
  • 【Bluebell】项目总结:基于 golang 的前后端分离 web 项目实战