Attention Free Transformer (AFT)-2021论文笔记
名称:
Attention Free Transformer (AFT)
来源:
[2105.14103] An Attention Free Transformer
相关工作:
#Approximatingthedotproduct #Sparselocalattention #Contextcompression #Eliminatingdotproductattention #MLPsforvision
创新点:
贡献:
- 提出了一种全新的注意力机制替代方案,完全摒弃了点积注意力。
- AFT的内存(空间)复杂度与上下文长度和特征维度均呈线性关系,因此可扩展到大输入和大模型;注意AFT-full的时间复杂度仍为 $O(T^2 d)$。
- AFT-local和AFT-conv变体通过引入局部性和空间权重共享,进一步提高了模型的效率和性能。
代码:
# ---------------------------------------
# 论文:An Attention Free Transformer (arxiv2021)
# ---------------------------------------
import torch
from torch import nn
from torch.nn import init
class AFT_FULL(nn.Module):
    """AFT-full layer from "An Attention Free Transformer" (arXiv:2105.14103).

    For every output position t:

        Y_t = sigmoid(Q_t) * sum_t' exp(K_t' + w[t, t']) * V_t'
                           / sum_t' exp(K_t' + w[t, t'])

    where Q, K, V are linear projections of the input and ``w`` is an
    (n, n) pairwise position bias (``self.position_biases``).

    Args:
        d_model: feature size of both input and output.
        n: sequence length. ``forward`` only accepts inputs with exactly this
            many positions because ``w`` is (n, n).
        simple: if True, ``w`` is a fixed, non-trainable all-zeros buffer;
            otherwise ``w`` is a learnable parameter initialised to ones.
    """

    def __init__(self, d_model: int, n: int = 49, simple: bool = False):
        super().__init__()
        self.fc_q = nn.Linear(d_model, d_model)
        self.fc_k = nn.Linear(d_model, d_model)
        self.fc_v = nn.Linear(d_model, d_model)
        if simple:
            # A buffer (unlike a plain tensor attribute) follows the module
            # through .to()/.cuda()/.double(). persistent=False keeps the
            # state_dict keys identical to the previous plain-attribute version.
            self.register_buffer('position_biases', torch.zeros((n, n)), persistent=False)
        else:
            self.position_biases = nn.Parameter(torch.ones((n, n)))
        self.d_model = d_model
        self.n = n
        self.sigmoid = nn.Sigmoid()
        self.init_weights()

    def init_weights(self):
        """Initialise submodules: Linear weights ~ N(0, 0.001), biases 0.

        AFT_FULL itself only contains Linear layers; the Conv2d / BatchNorm2d
        branches apply only if such submodules are present.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, input):
        """Apply AFT-full.

        Args:
            input: tensor of shape (bs, n, d_model).

        Returns:
            Tensor of shape (bs, n, d_model).

        Raises:
            ValueError: if the sequence length differs from ``self.n``.
        """
        bs, n, dim = input.shape
        if n != self.n:
            raise ValueError(
                f'AFT_FULL was built for sequence length {self.n}, got {n}')
        q = self.fc_q(input)  # bs,n,dim
        k = self.fc_k(input)  # bs,n,dim
        v = self.fc_v(input)  # bs,n,dim
        w = self.position_biases  # n,n ; w[t, t'] = bias of output t on input t'

        # exp(K_t' + w[t, t']) = exp(w[t, t']) * exp(K_t'), so both sums are a
        # matmul with exp(w). This avoids materialising an (n, bs, n, dim)
        # tensor. Subtracting the max of k (over positions) and of each row of
        # w multiplies numerator and denominator by the same factor, which
        # cancels in the ratio but keeps exp() from overflowing. The maxima are
        # detached because the result does not depend on them.
        exp_k = torch.exp(k - k.max(dim=1, keepdim=True).values.detach())  # bs,n,dim
        exp_w = torch.exp(w - w.max(dim=1, keepdim=True).values.detach())  # n,n
        numerator = torch.matmul(exp_w, exp_k * v)  # bs,n,dim
        denominator = torch.matmul(exp_w, exp_k)  # bs,n,dim
        out = numerator / denominator  # bs,n,dim
        return self.sigmoid(q) * out  # bs,n,dim
# Smoke test: input (B, N, C) -> output (B, N, C).
if __name__ == '__main__':
    # Fall back to CPU so the demo also runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    block = AFT_FULL(d_model=512, n=64).to(device)
    x = torch.rand(64, 64, 512, device=device)
    output = block(x)
    print(x.size(), output.size())