当前位置: 首页 > article >正文

(即插即用模块-特征处理部分) 十九、(NeurIPS 2023) Prompt Block 提示生成 / 交互模块

在这里插入图片描述

文章目录

  • 1、Prompt Block
  • 2、代码实现

paper:PromptIR: Prompting for All-in-One Blind Image Restoration

Code:https://github.com/va1shn9v/PromptIR


1、Prompt Block

在解决现有图像恢复模型时,现有研究存在一些局限性: 现有的图像恢复模型通常针对特定的退化类型(如去噪、去雾、去雨)进行训练,这会缺乏泛化能力,难以适应多种退化类型和级别。此外,现有的多退化图像恢复模型通常需要知道输入图像的退化类型,才能选择合适的模型进行恢复,这在实际应用中都是不太现实的。最后,现有的多退化图像恢复模型需要为每种退化类型和级别训练单独的模型,这会导致训练负担过重,且难以在资源受限的平台(如移动设备和边缘设备)上部署。为此,这篇论文提出一种 Prompt Block,其通过引入可学习的提示参数,将退化相关的信息编码到网络中,从而引导网络进行自适应的图像恢复。

Prompt Block 可以分为两个部分:即 Prompt Generation Module(PGM)Prompt Interaction Module(PIM)。具体来说,PGM 的目标是根据输入图像的特征动态生成 prompt 参数,使其能够更好地适应不同的退化类型。而 PIM 通过将 prompt P 与输入特征沿通道维度进行拼接,然后通过 Transformer block 进行处理,实现特征与 prompt 的交互。

对于一个输入特征 X,Prompt Block 的实现过程:

Prompt Generation Module:

  1. 对输入特征进行全局平均池化 (GAP),得到特征向量 v。
  2. 使用 1x1 卷积层对特征向量进行降维,得到紧凑的特征向量。
  3. 对降维后的特征向量进行 softmax 操作,得到 prompt 权重 w。
  4. 使用 prompt 权重 w 对 prompt 组件 Pc 进行加权求和,得到输入条件 prompt P。

Prompt Interaction Module:

  1. 首先将 prompt P 与输入特征 Fl 沿通道维度进行拼接。
  2. 将拼接后的特征通过 Transformer block 进行处理。
  3. 最后将特征经两层卷积处理,输出特征即为经过 Prompt Block 调整后的特征。

Prompt Generation / Interaction Module 结构图:
在这里插入图片描述

2、代码实现

import torch
from torch import nn, einsum
import torch.nn.functional as F


class PromptGenBlock(nn.Module):
    def __init__(self, prompt_dim, prompt_len=5, prompt_size=96, lin_dim=192):
        super(PromptGenBlock, self).__init__()
        self.prompt_param = nn.Parameter(torch.rand(1, prompt_len, prompt_dim, prompt_size, prompt_size))
        self.linear_layer = nn.Linear(lin_dim, prompt_len)
        self.conv3x3 = nn.Conv2d(prompt_dim, prompt_dim, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        B, C, H, W = x.shape
        emb = x.mean(dim=(-2, -1))
        prompt_weights = F.softmax(self.linear_layer(emb), dim=1)
        prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B, 1,
                                                                                                                  1, 1,
                                                                                                                  1,
                                                                                                                  1).squeeze(
            1)
        prompt = torch.sum(prompt, dim=1)
        prompt = F.interpolate(prompt, (H, W), mode="bilinear")
        prompt = self.conv3x3(prompt)

        return prompt


if __name__ == '__main__':
    x = torch.randn(4, 3, 64, 64).cuda()
    model = PromptGenBlock(3, lin_dim=3).cuda()
    out = model(x)
    print(out.shape)


http://www.kler.cn/a/527932.html

相关文章:

  • Spring JDBC:简化数据库操作的利器
  • 6.二分算法
  • Ruby Dir 类和方法详解
  • 抽象类与抽象方法详解
  • 搜索与图论复习1
  • 消息队列篇--通信协议篇--TCP和UDP(3次握手和4次挥手,与Socket和webSocket的概念区别等)
  • leetcode 844 比较含退格的字符串
  • 利用 AMD Instinct™ MI300X 提升计算流体动力学性能
  • cf1000(div.2)
  • 微服务实战 原生态实现服务的发现与调用_如何发现应用的服务调用问题
  • 从 UTC 日期时间字符串获取 Unix 时间戳:C 和 C++ 中的挑战与解决方案
  • P1158
  • 19 压测和常用的接口优化方案
  • Python从0到100(八十六):神经网络-ShuffleNet通道混合轻量级网络的深入介绍
  • 【4Day创客实践入门教程】Day3 实战演练——桌面迷你番茄钟
  • 1/31每日
  • 事务02之锁机制
  • Linux中部署Yolov5详解
  • 嵌入式知识点总结 Linux驱动 (八)-Linux设备驱动
  • H. Mad City
  • 深度学习编译器的演进:从计算图到跨硬件部署的自动化之路
  • 《大数据时代“快刀”:Flink实时数据处理框架优势全解析》
  • 翻译: Dario Amodei 关于DeepSeek与出口管制一
  • (二)QT——按钮小程序
  • 本地运行大模型效果及配置展示
  • 牛客周赛 Round 77