当前位置: 首页 > article >正文

普通的2D Average pooling是怎么进行backward的呢?

 

 

二维平均池层计算损失相对于其输入张量的梯度,方法是将损失相对于输出张量的梯度均分在输入子区域,这些子区域在前向传播中被用来计算平均值。

由于平均集合计算的是每个输入子区域的平均值,所以子区域中的每个元素对平均值的贡献是相同的。因此,在反向传播过程中,与子区域中每个元素有关的损失梯度等于与相应的输出元素有关的损失梯度除以子区域的大小

import torch
import torch.nn as nn

avg_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))

x = torch.randn(1, 1, 4, 4, requires_grad=True)
y = avg_pool(x)

# Compute some loss
loss = y.sum()
print(loss) # tensor(-0.1558, grad_fn=<SumBackward0>)

# Compute gradients
loss.backward()

# Check gradients
print(x.grad)
# tensor([[[[ 0.0625,  0.0625, -0.1875, -0.1875],
#           [ 0.0625,  0.0625, -0.1875, -0.1875],
#           [-0.1250, -0.1250, -0.0625, -0.0625],
#           [-0.1250, -0.1250, -0.0625, -0.0625]]]])

结果就是:

tensor([[[[0.2500, 0.2500, 0.2500, 0.2500],
          [0.2500, 0.2500, 0.2500, 0.2500],
          [0.2500, 0.2500, 0.2500, 0.2500],
          [0.2500, 0.2500, 0.2500, 0.2500]]]])

解读:

在这个例子中,创建了一个具有kernel_size=(2 ,2) and stride=(2 ,2)的AvgPool2d层。一个形状为(1 ,1 ,4 ,4)且require_grad=True的输入张量x通过该层,产生一个形状为(1 ,1 ,2 ,2)的输出张量y。

损失被计算为y中所有元素之和,其值为-0.1558。在反向传播过程中,损失相对于x的梯度是通过将损失相对于y的梯度分布在用于计算y的平均值的(2,2)输入子区域来计算的。

在这种情况下,由于y中有四个输出元素,它们的梯度是[1 ,1 ,1 ,1]。这些梯度分布在各自的(2 ,2)输入子区域,子区域的每个元素得到的梯度等于其对应的输出元素的梯度除以(2 * 2) = 4。因此,x.grad中的元素的值为[1/4 ,1/4 ,1/4 ,1/4]。


http://www.kler.cn/a/16194.html

相关文章:

  • 斯坦福泡茶机器人DexCap源码解析:涵盖收集数据、处理数据、模型训练三大阶段
  • 数据产品:深度探索与案例剖析
  • 【JAVA基础】JVM是什么?
  • 深入解析贪心算法及其应用实例
  • 如何保护 Microsoft 网络免受中间人攻击
  • Jetpack 之 Ink API初探
  • [Pandas] 查看DataFrame的常用属性
  • 云原生CAx软件:多租户的认证
  • MySQL数据库,JDBC连接数据库操作流程详细介绍
  • 西门子PLC沿脉冲类指令汇总
  • 5.5.1哈夫曼树
  • GDKOI 2023游记总结
  • 【BeautifulSoup上】——05全栈开发——如桃花来
  • Afkayas.1(★)
  • 学习系统编程No.20【进程间通信之命名管道】
  • 大数据架构(一)背景和概念
  • 从0搭建Vue3组件库(十一): 集成项目的编程规范工具链(ESlint+Prettier+Stylelint)
  • 盈泰德带你了解产品表面缺陷检测系统
  • Idea关闭或开启引用提示Usages和Annotations
  • Vulnhub:DerpNStink 1靶机
  • C语言程序设计:某班有5名同学,建立一个学生的简单信息表,包括学号、姓名、3门课程的成绩,编写程序,计算每名学生的平均成绩及名次。
  • 配置Bridge模式KVM虚拟机
  • 第六章结构型模式—代理模式
  • Springboot +Flowable,设置任务处理人的四种方式(一)
  • Android Java 音频采集 AudioRecord
  • 【C++】类和对象