当前位置: 首页 > article >正文

注意力汇聚 笔记

10.2. 注意力汇聚:Nadaraya-Watson 核回归 — 动手学深度学习 2.0.0 documentation

要是有错请指正

想了半天这个玩意才搞懂一点点,

1.设计出一个函数 给予真实的连续x 得出真实连续的y

2.随机生成用于测试的一批不连续的 离散的x  并将测试x放进函数中并加入一定的偏差得到测试y

3.此时假设不知道函数的表达式 只用测试x和测试y在图上标点

3.1最简单的情况 将所有测试y取平均做一条直线拟合原函数 这也叫平均汇聚 ,结果明显是不行的

3.2 平均汇聚失败的原因是忽略了输入xi

4.我们将真实x复制50列 就是说 某一行所有列全部相同,每一行都不同构成(50,50)的大小真实x矩阵

真实x矩阵(简称a矩阵)再减去要预测的x_train  应用广播机制运算

结果就是比如a矩阵 第0行所有列数(每一个数都一样),减去x_train这个向量

 

实际上a矩阵的意义就是看每个真实x与x_train的距离,而我们可以看到其实只有第一个值距离最短.

也可以得知每行只有一个数是x与x_train距离最小的值,而为了防止距离太大比如将来的值可能会出现9999999这么长的数,用softmaxt归一化, 简单思想就是每个数占最大的是的一个百分比.归一化之后可以将矩阵变小很多. 为了方便看这个矩阵的距离画图和用颜色标出来

5.得出权重图,我们可以发现 两个x向量都是从小到大排序的

类似思想是y约等于x,但是这里y轴颠倒了,所以也类似于y等于-x在往上拉的样子,在这条直线上.两个值距离越近

而颜色越深就意味着距离越短,但是这里面数据并没有两个x都相等的情况 也就是x_test等于x_train的情况 所以是没有颜色为1的值的,最大也不到0.2  

6.那么这个矩阵有什么用呢

从3.1的值我们y的平均值做一条直线 其实也可以描述来拟合这条曲线,但是他缺少对于xi的思考,比如说在xi的时候平均值应该往xi对应的yi上增加或者减少一点值,使得这条直线更加符合对于当前yi的一个值,如果加满了xi的这个值就直接变成了yi,也就是注意力拉满,但是对于全局来说不能这么看,所以还需要观看其他x的值,做一个全局观,所以需要做成矩阵a

也就是看xi的同时 主要集中注意力在xi,但是对于其他在xi隔壁甚至很远的x也要注意到,但是越远注意力就要分散一些,然后适用在这些y上面再变更y的值,最后使得最后拟合的这条曲线对于yi不会太过拟合,也不会欠拟合的样子

再说来,就是拟合的线即顾全局,也看当时的情况,但是注意力主要集中在当时的情况.

eg1.

如果直接平均汇聚,平均值是2.1868这条直线,真实值是0,而当时train的y也只有0.4415, 远小于这个值,但是其他y大,所以平均下来是有点离谱的,所以运用注意力机制 最后拟合的时候该点的y多看一点当前的yi少看一点其他的值(如果众生平等,大家的值都平等对待再平均就是y_mean,很大),就可以减少别的很大的y值对这个yi值的影响

 

 

紫色是最后拟合的线,根据不同注意力对不同y拉伸平均值的样子  

总感觉上面说的话不太像人话

import torch
from torch import nn
from d2l import torch as d2l

n_train = 50  # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5)   # 排序后的训练样本 rand是0到1之间 生成50个再*5 就是取0-5之间50个数

def f(x):
    return 2 * torch.sin(x) + x**0.8

y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出
x_test = torch.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数
n_test

def plot_kernel_reg(y_hat):
    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);

y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)
d2l.plt.show();
# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重.   x_test每个值对x_train所有值的影响和关联
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)

d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape

weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))

class NWKernelRegression(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

    def forward(self, queries, keys, values):
        # queries和attention_weights的形状为(查询个数,“键-值”对个数)
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
        self.attention_weights = nn.functional.softmax(
            -((queries - keys) * self.w)**2 / 2, dim=1)
        # values的形状为(查询个数,“键-值”对个数)
        return torch.bmm(self.attention_weights.unsqueeze(1),
                         values.unsqueeze(-1)).reshape(-1)

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train, 1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))

net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])

for epoch in range(5):
    trainer.zero_grad()
    l = loss(net(x_train, keys, values), y_train)
    l.sum().backward()
    trainer.step()
    print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
    animator.add(epoch + 1, float(l.sum()))

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)

d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')
d2l.plt.show();


http://www.kler.cn/a/4522.html

相关文章:

  • 【大数据】机器学习-----线性模型
  • 【2025 Rust学习 --- 17 文本和格式化 】
  • 【2024年华为OD机试】(C卷,100分)- 分割均衡字符串 (Java JS PythonC/C++)
  • MyBatis——XML映射文件
  • Qiskit快速编程探索(进阶篇)
  • 极客说|Azure AI Agent Service 结合 AutoGen/Semantic Kernel 构建多智能体解决⽅案
  • IO进程线程-标准IO(结)
  • 探究C/C++ typedef的秘密
  • Mysql排序后分页 分页数据有重复
  • python条件语句与循环语句
  • Unity游戏崩溃日志查询笔记 安卓平台 关于tombstone_00
  • TCC真没这么简单,一文讲透|分布式事务系列(三)
  • 面试官常问的设计模式及常用框架中设计模式的使用(一)
  • 树莓派学习笔记(八)树莓派Linux内核开发准备工作及概念
  • Java基础 -- 关键字Static和Final
  • docker-compose部署rabbitmq集群
  • 解决 Git 错误 error: failed to push some refs to ‘https://*****.git‘
  • 树莓派学习笔记(十三)基于框架编写驱动代码
  • 春分-面试
  • LeetCode:242. 有效的字母异位词
  • MySQL OCP888题解063-突然变慢的可能原因
  • 【Autoware规控】Lattice规划节点
  • CentOS挂载U盘拷贝文件
  • 【基础算法】1-2:归并排序
  • MyBatis-Plus联表查询(Mybatis-Plus-Join)
  • RabbitMQ高级