[AI Knowledge] Hand-Writing Attention in PyTorch: Self-Attention and Multi-Head-Attention
Hand-written Attention in PyTorch:
- Self-Attention
- Multi-Head-Attention
Self-Attention
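Recall the formula the code below implements: scaled dot-product self-attention computes Attention(Q, K, V) = softmax(Q·K^T / sqrt(d)) · V, where Q, K, V are linear projections of the same input x and d is the dimension used for scaling (embed_dim in this implementation).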
Code:
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super(SelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.WQ = nn.Linear(embed_dim, embed_dim)
        self.WK = nn.Linear(embed_dim, embed_dim)
        self.WV = nn.Linear(embed_dim, embed_dim)
        # defined here but not applied in forward(); it would normally be applied to attention_weights
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, mask=None):
        """
        Input sequence x: (batch_size, seq_len, embed_dim)
        """
        # (batch_size, seq_len, embed_dim)
        Q = self.WQ(x)
        # (batch_size, seq_len, embed_dim)
        K = self.WK(x)
        # (batch_size, seq_len, embed_dim)
        V = self.WV(x)
        # K: (batch_size, seq_len, embed_dim); K.transpose(-2, -1) swaps the last two dimensions,
        # so the scores are (batch_size, seq_len, seq_len), scaled by sqrt(embed_dim)
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.embed_dim ** 0.5)
        # set the masked positions to -inf
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))
        # softmax over the last dimension (dim=-1), i.e. over seq_len, so each query's attention weights sum to 1
        attention_weights = torch.softmax(attention_scores, -1)
        output = torch.matmul(attention_weights, V)
        return output, attention_weights

def create_causal_mask(seq_len):
    """
    Build a (seq_len, seq_len) upper-triangular causal mask
    """
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)  # strictly upper-triangular part is 1
    return mask == 0  # True = can attend, False = masked
Test:
batch_size=2
seq_len=5
embed_dim=5
x=torch.rand(batch_size,seq_len,embed_dim)
mask_true=create_causal_mask(seq_len)
print("mask:")
print(mask_true)
print("-"*50)
mask_false=None
self_attention=SelfAttention(embed_dim)
output_with_mask,weights_with_mask=self_attention(x,mask_true)
print("output_with_mask:")
print(output_with_mask)
print("weights_with_mask:")
print(weights_with_mask)
print("-"*50)
output_without_mask,weights_without_mask=self_attention(x,mask_false)
print("output_without_mask:")
print(output_without_mask)
print("weights_without_mask:")
print(weights_without_mask)
Result:
mask:
tensor([[ True, False, False, False, False],
[ True, True, False, False, False],
[ True, True, True, False, False],
[ True, True, True, True, False],
[ True, True, True, True, True]])
--------------------------------------------------
output_with_mask:
tensor([[[-0.0250, -0.0048, 0.1955, 0.1222, 0.3228],
[ 0.0082, -0.1107, 0.2676, 0.1467, 0.4512],
[ 0.0056, -0.1283, 0.3186, 0.1582, 0.3351],
[ 0.0030, -0.0939, 0.2760, 0.1519, 0.3447],
[ 0.0192, -0.1143, 0.3045, 0.1725, 0.3280]],
[[-0.0036, -0.4164, 0.3813, 0.2492, 0.5639],
[-0.1500, -0.3267, 0.2072, 0.0787, 0.4852],
[-0.0660, -0.2758, 0.2731, 0.1216, 0.4643],
[-0.0864, -0.2271, 0.2297, 0.0849, 0.4300],
[-0.0653, -0.1743, 0.2279, 0.0985, 0.3965]]],
grad_fn=<UnsafeViewBackward0>)
weights_with_mask:
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.4885, 0.5115, 0.0000, 0.0000, 0.0000],
[0.3256, 0.3318, 0.3427, 0.0000, 0.0000],
[0.2447, 0.2653, 0.2568, 0.2332, 0.0000],
[0.1945, 0.2031, 0.2080, 0.1886, 0.2058]],
[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.5361, 0.4639, 0.0000, 0.0000, 0.0000],
[0.3654, 0.3014, 0.3332, 0.0000, 0.0000],
[0.2744, 0.2371, 0.2544, 0.2340, 0.0000],
[0.2194, 0.1978, 0.2025, 0.1885, 0.1918]]],
grad_fn=<SoftmaxBackward0>)
--------------------------------------------------
output_without_mask:
tensor([[[ 0.0190, -0.1140, 0.3034, 0.1720, 0.3307],
[ 0.0191, -0.1138, 0.3035, 0.1722, 0.3295],
[ 0.0191, -0.1133, 0.3035, 0.1724, 0.3276],
[ 0.0191, -0.1147, 0.3040, 0.1720, 0.3311],
[ 0.0192, -0.1143, 0.3045, 0.1725, 0.3280]],
[[-0.0623, -0.1730, 0.2303, 0.1011, 0.3960],
[-0.0615, -0.1707, 0.2301, 0.1013, 0.3945],
[-0.0601, -0.1736, 0.2327, 0.1035, 0.3966],
[-0.0621, -0.1722, 0.2301, 0.1011, 0.3955],
[-0.0653, -0.1743, 0.2279, 0.0985, 0.3965]]],
grad_fn=<UnsafeViewBackward0>)
weights_without_mask:
tensor([[[0.1964, 0.2094, 0.2038, 0.1892, 0.2012],
[0.1961, 0.2053, 0.2040, 0.1902, 0.2044],
[0.1953, 0.1990, 0.2055, 0.1928, 0.2075],
[0.1958, 0.2123, 0.2055, 0.1866, 0.1998],
[0.1945, 0.2031, 0.2080, 0.1886, 0.2058]],
[[0.2205, 0.1906, 0.2044, 0.1864, 0.1981],
[0.2171, 0.1878, 0.2034, 0.1879, 0.2038],
[0.2250, 0.1856, 0.2051, 0.1834, 0.2009],
[0.2195, 0.1897, 0.2035, 0.1872, 0.2002],
[0.2194, 0.1978, 0.2025, 0.1885, 0.1918]]],
grad_fn=<SoftmaxBackward0>)
Explanations of some usages in the code:
1) torch.nn.Linear()
torch.nn.Linear() is PyTorch's most basic fully connected (linear transformation) layer. Its constructor is:
nn.Linear(in_features, out_features, bias=True)
where in_features is the input feature dimension, out_features is the output feature dimension, and bias controls whether a bias term is used (default True).
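A minimal shape check (example values chosen here for illustration, not from the original code):
linear = nn.Linear(5, 7)      # in_features=5, out_features=7
x = torch.rand(2, 4, 5)       # (batch_size, seq_len, in_features)
y = linear(x)                 # the linear layer acts on the last dimension
print(y.shape)                # torch.Size([2, 4, 7])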
2) K.transpose(-2, -1)
In PyTorch, torch.transpose() swaps two dimensions of a tensor. The arguments -2 and -1 refer to the second-to-last and the last dimensions.
K.transpose(-2, -1) and K.transpose(-1, -2) both swap the last two dimensions; they are exactly equivalent.
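A quick shape check (illustrative values):
K = torch.rand(2, 5, 8)                                          # (batch_size, seq_len, embed_dim)
print(K.transpose(-2, -1).shape)                                 # torch.Size([2, 8, 5])
print(torch.equal(K.transpose(-2, -1), K.transpose(-1, -2)))     # True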
3)torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
def create_causal_mask(seq_len):
    """
    Build a (seq_len, seq_len) upper-triangular causal mask
    """
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)  # strictly upper-triangular part is 1
    return mask == 0  # True = can attend, False = masked
Explanation of this function:
- torch.ones(seq_len, seq_len) creates a seq_len × seq_len matrix in which every element is 1
- torch.triu(…, diagonal=1) keeps the strictly upper triangle (excluding the main diagonal): the upper triangle is 1, everything else is 0
- mask == 0 turns 0 into True (can attend) and 1 into False (masked)
For seq_len = 5 this returns the following mask matrix:
tensor([[ True, False, False, False, False],
[ True, True, False, False, False],
[ True, True, True, False, False],
[ True, True, True, True, False],
[ True, True, True, True, True]])
4) attention_scores.masked_fill(mask==0, float('-inf'))
Purpose: apply the mask to the scores by filling the positions where mask == 0 with -inf (negative infinity), so that their weights become 0 in the softmax.
masked_fill is a method of PyTorch tensors (torch.Tensor) that fills values according to a boolean mask. Its syntax is: tensor.masked_fill(mask, value)
tensor: the tensor to modify
mask: a boolean mask (True/False or 0/1)
value: the value to fill in (e.g. -inf)
Explanation of attention_scores.masked_fill(mask==0, float('-inf')):
mask == 0 selects the positions that should be blocked (i.e. the upper-triangular part)
masked_fill(mask == 0, -inf) sets that upper-triangular part to -inf
As a result, the masked positions become 0 after softmax and no longer affect the attention computation.
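A small numeric example of this effect (illustrative values):
scores = torch.tensor([[1.0, 2.0, 3.0]])
keep = torch.tensor([[True, True, False]])             # last position should be masked
masked = scores.masked_fill(keep == 0, float('-inf'))
print(torch.softmax(masked, -1))                       # tensor([[0.2689, 0.7311, 0.0000]])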
5) torch.softmax(attention_scores, -1)
Purpose: apply softmax normalization to attention_scores so that the attention weights sum to 1, controlling how strongly each token attends to every other token in the sequence.
Syntax: torch.softmax(input, dim)
input: the tensor to apply softmax to
dim: the dimension along which softmax is computed (dim=-1 means the last dimension)
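A minimal check that each query's weights sum to 1 (illustrative values):
scores = torch.rand(2, 5, 5)
weights = torch.softmax(scores, -1)   # normalize over the last dimension (the key positions)
print(weights.sum(-1))                # all entries are 1 (up to floating point)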
Multi-Head-Attention
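Multi-head attention splits embed_dim into nums_heads heads of size head_dim = embed_dim / nums_heads, runs scaled dot-product attention in every head in parallel, then concatenates the heads and applies a final linear projection: MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_O, where each head_i = Attention(Q_i, K_i, V_i).
Code: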
import torch
import torch.nn as nn
import math
class Multi_Head_Attention(nn.Module):
    def __init__(self, embed_dim, nums_heads):
        super(Multi_Head_Attention, self).__init__()
        assert embed_dim % nums_heads == 0, "embed_dim must be divisible by nums_heads"
        self.embed_dim = embed_dim
        self.nums_heads = nums_heads
        self.head_dim = embed_dim // nums_heads
        self.WQ = nn.Linear(embed_dim, embed_dim)
        self.WK = nn.Linear(embed_dim, embed_dim)
        self.WV = nn.Linear(embed_dim, embed_dim)
        self.fc = nn.Linear(embed_dim, embed_dim)
        # scale by sqrt(head_dim), the per-head key dimension used in scaled dot-product attention
        self.scale = math.sqrt(self.head_dim)

    def forward(self, x, mask=None):
        batch_size, seq_len, embed_dim = x.shape
        # Q, K, V: (batch_size, seq_len, embed_dim)
        Q = self.WQ(x)
        K = self.WK(x)
        V = self.WV(x)
        # (batch_size, seq_len, embed_dim) -> (batch_size, seq_len, nums_heads, head_dim) -> (batch_size, nums_heads, seq_len, head_dim)
        Q = Q.view(batch_size, seq_len, self.nums_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.nums_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.nums_heads, self.head_dim).transpose(1, 2)
        # (batch_size, nums_heads, seq_len, seq_len)
        attn_scores = torch.matmul(Q, K.transpose(-1, -2)) / self.scale
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = torch.softmax(attn_scores, -1)
        # (batch_size, nums_heads, seq_len, head_dim)
        output = torch.matmul(attn_weights, V)
        # (batch_size, nums_heads, seq_len, head_dim) -> (batch_size, seq_len, nums_heads, head_dim) -> (batch_size, seq_len, embed_dim)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        output = self.fc(output)
        return output, attn_weights

def create_mask(seq_len):
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    return mask == 0
Test:
batch_size=2
seq_len=5
nums_heads=3
embed_dim=6
x=torch.randn(batch_size,seq_len,embed_dim)
mask=create_mask(seq_len)
multiheadattention=Multi_Head_Attention(embed_dim,nums_heads)
output,weights=multiheadattention(x,mask)
print(output)
print(weights)
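As an optional shape-level sanity check (not part of the original code), the hand-written module can be compared against PyTorch's built-in nn.MultiheadAttention. Note that its boolean attn_mask uses the opposite convention (True means the position may NOT be attended to), so the causal mask built above has to be inverted:
mha = nn.MultiheadAttention(embed_dim, nums_heads, batch_first=True)
attn_mask = ~create_mask(seq_len)           # invert: True = masked out
out, w = mha(x, x, x, attn_mask=attn_mask)
print(out.shape)   # torch.Size([2, 5, 6]), same shape as the hand-written output
print(w.shape)     # torch.Size([2, 5, 5]); attention weights are averaged over heads by default
The numerical values will of course differ, since the two modules have independently initialized parameters.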