当前位置: 首页 > article >正文

利用前向勾子获取神经网络中间层的输出并将其进行保存(示例详解)

代码示例:

# 激活字典,用于保存每次的中间特征
activation = {}

# 将 forward_hook 函数定义在 upsample_v2 外部
def forward_hook(name):
    def hook(module, input, output):
        activation[name] = output.detach()

    return hook

def upsample_v2(in_channels, out_channels, upscale, kernel_size=3):
    layers = []

    # Define mid channel stages (three times reduction)
    mid_channels = [256, 128, 64]  # 512 32 32 -> 256 64 64 -> 128 128 128 -> 64 256 256 -> 2 256 256
    scale_factor_per_step = upscale ** (1/3)  # Calculate the scaling for each step

    current_in_channels = in_channels

    # Upsample and reduce channels in 3 steps
    for step, mid_channel in enumerate(mid_channels):
        # Conv layer to reduce number of channels
        conv = nn.Conv2d(current_in_channels, mid_channel, kernel_size=kernel_size, padding=1, bias=False)
        nn.init.kaiming_normal_(conv.weight.data, nonlinearity='relu')
        layers.append(conv)

        # ReLU activation
        relu = nn.ReLU()
        layers.append(relu)

        # Upsampling layer
        up = nn.Upsample(scale_factor=scale_factor_per_step, mode='bilinear', align_corners=True)
        layers.append(up)

        layers[-1].register_forward_hook(forward_hook(f'step_{step}'))

        # Update current in_channels for the next layer
        current_in_channels = mid_channel

    conv = nn.Conv2d(current_in_channels, out_channels, kernel_size=kernel_size, padding=1, bias=False)
    nn.init.kaiming_normal_(conv.weight.data, nonlinearity='relu')
    layers.append(conv)

    return nn.Sequential(*layers)
def forward_hook(name):
    def hook(module, input, output):
        activation[name] = output.detach()

    return hook

forward_hook布置了抓取函数。其中,module代表你下面勾的那一层,input代表那一层的输入,output定义那一层的输出,我们常常只使用output。

layers[-1].register_forward_hook(forward_hook(f'step_{step}'))

这里定义了我需要捕获的那一层,layers[-1]代表我要捕获当前layers的最后一层,即上采用层,由于循环了三次,所以最后勾取的应当是三份中间层输出。


http://www.kler.cn/a/371781.html

相关文章:

  • 使用强化学习训练神经网络玩俄罗斯方块
  • 数据结构(1~10)
  • Flink源码解析之:Flink on k8s 客户端提交任务源码分析
  • 创建Java项目,并添加MyBatis包和驱动包
  • nginx 日志规范化意义及实现!
  • 回顾2024年重磅AI发布汇总
  • shodan5,参数使用,批量查找Mongodb未授权登录,jenkins批量挖掘
  • QT编辑框带行号
  • 迷你航拍高清智能无人机技术详解
  • 云服务器和物理服务器有区别吗?
  • docker使用简介
  • 【WRF数据处理】基于GIS4WRF插件将geotiff数据转为tiff(geogrid,WPS所需数据)
  • AI Agents - 自动化项目:计划、评估和分配
  • JAVA的设计模式都有那些
  • ppt演示如何计时?分享2个ppt使用技巧,轻松搞定ppt计时!
  • STM32 从0开始系统学习4 编写LED驱动
  • 基于Java语言的充电桩管理系统
  • DICOM标准:DICOM服务类详解,了解存储服务类、查询/检索服务类(Q/R Service Class)和工作流管理服务类等原理
  • 无人机协同控制技术详解!
  • pdf免费压缩软件 pdf文件压缩免费软件 软件工具方法
  • 人类借助AI发现第 52 个梅森素数
  • cloak斗篷伪装下的独立站
  • 被上传文件于后端的命名策略
  • Typora 、 Minio and PicGo 图床搭建
  • uniapp 图片bug(图片为线上地址,url不变,内容更新)
  • 红黑树模拟封装map和set