当前位置: 首页 > article >正文

每日Attention学习27——Patch-based Graph Reasoning

模块出处

[NC 25] [link] Graph-based context learning network for infrared small target detection


模块名称

Patch-based Graph Reasoning (PGR)


模块结构

在这里插入图片描述


模块特点
  • 使用图结构更好的捕捉特征的全局上下文
  • 将图结构与特征切片(Patching)相结合,从而促进全局/局部特征互补

模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class graph(nn.Module):
    def __init__(self, p2=4, nIn=64, N=16):
        super(graph, self).__init__()
        self.p2 = p2
        self.N = N
        self.conv30 = nn.Sequential(
            nn.Conv2d(nIn, self.N, kernel_size=3, stride=1, padding=1, groups=1),
            nn.ReLU(inplace=True)
        )
        self.conv10 = nn.Sequential(
            nn.Conv1d(nIn, nIn, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True)
        )
        self.conv11 = nn.Sequential(
            nn.Conv1d(self.N, self.N, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True)
        )
        self.adaptivemax = nn.AdaptiveAvgPool2d((8, 8))
        self.conv12 = nn.Sequential(
            nn.Conv1d(p2 ** 2, p2, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv1d(p2, p2, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv1d(p2, p2 ** 2, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid()
        )

    def ADP_weight(self, x):
        b, C, H, W = x.shape
        fg = self.adaptivemax(x)  
        fg1 = fg.view(b, C, self.p2 ** 2)  
        fg1 = torch.transpose(fg1, 1, 2)  
        fg2 = self.conv12(fg1)  
        fg3 = fg2.unsqueeze(-1).unsqueeze(-1)
        return fg3

    def graph_convolution(self, fs, x):
        b, C, H, W = x.shape
        h, w = H // self.p2, W // self.p2
        L = h * w
        B = self.conv30(fs)  
        B1 = B.view(-1, self.N, L)  
        fs1 = fs.view(-1, C, L)  
        fs1 = torch.transpose(fs1, 1, 2) 
        fs2 = torch.bmm(B1, fs1)  
        fs3 = self.conv11(fs2)  
        fs5 = self.conv10(torch.transpose(fs3, 1, 2))  
        fs6 = torch.bmm(torch.transpose(B1, 1, 2), torch.transpose(fs5, 1, 2))
        fs6 = torch.transpose(fs6, 1, 2) 
        fs6 = fs6.view(b, self.p2 ** 2, C, h, w) 
        return fs6

    def forward(self, fs, x):
        fs6 = self.graph_convolution(fs, x)
        weight = self.ADP_weight(x)
        out = weight * fs6
        return out
    

class PGR(nn.Module):
    def __init__(self, p2=4, nIn=32, nOut=32, add=True):
        super(PGR, self).__init__()
        self.p2 = p2
        self.N = nIn // 4
        self.add = add
        self.graph0 = graph(p2, nIn, self.N)
        self.conv31 = nn.Sequential(
            nn.Conv2d(nOut, nOut, kernel_size=1, stride=1),
            nn.BatchNorm2d(nOut),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        b, C, H, W = x.shape
        h, w = H // self.p2, W // self.p2
        L = h * w
        fs = torch.zeros((b, self.p2 ** 2, C, h, w)).cuda()
        for i in range(1, self.p2 + 1):
            for j in range(1, self.p2 + 1):
                fs[:, i * j - 1, :, :, :] = x[:, :, (i - 1) * h: i * h, (j - 1) * w: j * w]
        fs = fs.view(b * self.p2 ** 2, C, h, w)
        fs6 = self.graph0(fs, x)
        out = torch.zeros_like(x)
        for i in range(1, self.p2 + 1):
            for j in range(1, self.p2 + 1):
                out[:, :, (i - 1) * h: i * h, (j - 1) * w: j * w] = fs6[:, i * j - 1, :, :, :]
        out = self.conv31(out)
        if self.add:
            out = out + x
        return out


if __name__ == '__main__':
    x = torch.randn([1, 64, 44, 44]).cuda()
    pgr = PGR(p2=8, nIn=64, nOut=64).cuda()
    out = pgr(x)
    print(out.shape) # [1, 64, 44, 44]


http://www.kler.cn/a/588533.html

相关文章:

  • 【从零开始学习计算机科学】软件工程(六)软件质量
  • Docker基础知识介绍
  • 【Python+HTTP接口】POST请求不同请求头构造
  • 【ASMbits--常用算术运算指令】
  • 深入解析 FID:深度学习生成模型评价指标
  • pyQT学习笔记——Qt常用组件与绘图类的使用指南
  • 【商城实战(36)】UniApp性能飞升秘籍:从渲染到编译的深度优化
  • 使用memmove优化插入排序
  • 软件架构设计习题及复习
  • 计算机网络——NAT
  • 【Linux】Socket 编程 TCP
  • 《Python深度学习》第四讲:计算机视觉中的深度学习
  • 在Simulink中将Excel数据导入可变负载模块的方法介绍
  • 工程化与框架系列(30)--前端日志系统实现
  • cursor全栈网页开发最合适的技术架构和开发语言
  • JVM系统变量的妙用
  • 树莓派 连接 PlutoSDR 教程
  • Typedef 与enum的使用
  • 【人工智能基础2】人工神经网络、卷积神经网络基础、循环神经网络、长短时记忆网络
  • [蓝桥杯]花束搭配【算法赛】