注意力汇聚 笔记
10.2. 注意力汇聚:Nadaraya-Watson 核回归 — 动手学深度学习 2.0.0 documentation
要是有错请指正
想了半天这个玩意才搞懂一点点,
1.设计出一个函数 给予真实的连续x 得出真实连续的y
2.随机生成用于测试的一批不连续的 离散的x 并将测试x放进函数中并加入一定的偏差得到测试y
3.此时假设不知道函数的表达式 只用测试x和测试y在图上标点
3.1最简单的情况 将所有测试y取平均做一条直线拟合原函数 这也叫平均汇聚 ,结果明显是不行的
3.2 平均汇聚失败的原因是忽略了输入xi
4.我们将真实x复制50列 就是说 某一行所有列全部相同,每一行都不同构成(50,50)的大小真实x矩阵
真实x矩阵(简称a矩阵)再减去要预测的x_train 应用广播机制运算
结果就是比如a矩阵 第0行所有列数(每一个数都一样),减去x_train这个向量
实际上a矩阵的意义就是看每个真实x与x_train的距离,而我们可以看到其实只有第一个值距离最短.
也可以得知每行只有一个数是x与x_train距离最小的值,而为了防止距离太大比如将来的值可能会出现9999999这么长的数,用softmaxt归一化, 简单思想就是每个数占最大的是的一个百分比.归一化之后可以将矩阵变小很多. 为了方便看这个矩阵的距离画图和用颜色标出来
5.得出权重图,我们可以发现 两个x向量都是从小到大排序的
类似思想是y约等于x,但是这里y轴颠倒了,所以也类似于y等于-x在往上拉的样子,在这条直线上.两个值距离越近
而颜色越深就意味着距离越短,但是这里面数据并没有两个x都相等的情况 也就是x_test等于x_train的情况 所以是没有颜色为1的值的,最大也不到0.2
6.那么这个矩阵有什么用呢
从3.1的值我们y的平均值做一条直线 其实也可以描述来拟合这条曲线,但是他缺少对于xi的思考,比如说在xi的时候平均值应该往xi对应的yi上增加或者减少一点值,使得这条直线更加符合对于当前yi的一个值,如果加满了xi的这个值就直接变成了yi,也就是注意力拉满,但是对于全局来说不能这么看,所以还需要观看其他x的值,做一个全局观,所以需要做成矩阵a
也就是看xi的同时 主要集中注意力在xi,但是对于其他在xi隔壁甚至很远的x也要注意到,但是越远注意力就要分散一些,然后适用在这些y上面再变更y的值,最后使得最后拟合的这条曲线对于yi不会太过拟合,也不会欠拟合的样子
再说来,就是拟合的线即顾全局,也看当时的情况,但是注意力主要集中在当时的情况.
eg1.
如果直接平均汇聚,平均值是2.1868这条直线,真实值是0,而当时train的y也只有0.4415, 远小于这个值,但是其他y大,所以平均下来是有点离谱的,所以运用注意力机制 最后拟合的时候该点的y多看一点当前的yi少看一点其他的值(如果众生平等,大家的值都平等对待再平均就是y_mean,很大),就可以减少别的很大的y值对这个yi值的影响
紫色是最后拟合的线,根据不同注意力对不同y拉伸平均值的样子
总感觉上面说的话不太像人话
import torch
from torch import nn
from d2l import torch as d2l
n_train = 50 # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5) # 排序后的训练样本 rand是0到1之间 生成50个再*5 就是取0-5之间50个数
def f(x):
return 2 * torch.sin(x) + x**0.8
y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,)) # 训练样本的输出
x_test = torch.arange(0, 5, 0.1) # 测试样本
y_truth = f(x_test) # 测试样本的真实输出
n_test = len(x_test) # 测试样本数
n_test
def plot_kernel_reg(y_hat):
d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
xlim=[0, 5], ylim=[-1, 5])
d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);
y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)
d2l.plt.show();
# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重. x_test每个值对x_train所有值的影响和关联
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)
d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
xlabel='Sorted training inputs',
ylabel='Sorted testing inputs')
X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape
weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))
class NWKernelRegression(nn.Module):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.w = nn.Parameter(torch.rand((1,), requires_grad=True))
def forward(self, queries, keys, values):
# queries和attention_weights的形状为(查询个数,“键-值”对个数)
queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
self.attention_weights = nn.functional.softmax(
-((queries - keys) * self.w)**2 / 2, dim=1)
# values的形状为(查询个数,“键-值”对个数)
return torch.bmm(self.attention_weights.unsqueeze(1),
values.unsqueeze(-1)).reshape(-1)
# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train, 1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])
for epoch in range(5):
trainer.zero_grad()
l = loss(net(x_train, keys, values), y_train)
l.sum().backward()
trainer.step()
print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
animator.add(epoch + 1, float(l.sum()))
# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)
d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
xlabel='Sorted training inputs',
ylabel='Sorted testing inputs')
d2l.plt.show();