# Daily Attention Learning 27: Patch-based Graph Reasoning
## Module Source
[NC 25] [link] Graph-based context learning network for infrared small target detection
## Module Name
Patch-based Graph Reasoning (PGR)
## Module Structure

PGR splits the input feature map into p2 × p2 non-overlapping patches, performs graph reasoning inside each patch, re-weights the patches with a descriptor pooled from the whole map, stitches the patches back together, and fuses the result with the input through a 1 × 1 convolution and a residual connection.
## Module Highlights
- Uses a graph structure to better capture the global context of features
- Combines the graph structure with feature patching, so that global and local features complement each other (the sketch after this list isolates the underlying idea)
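
At its core, PGR's graph reasoning follows the familiar projection pattern of latent-graph methods: a small convolution predicts an assignment matrix B that maps the L = h·w pixels of a patch onto N graph nodes, the node features are updated with 1D convolutions, and Bᵀ scatters the result back onto the pixel grid. The standalone sketch below (class name, node count, and channel sizes are illustrative choices, not the paper's) shows just that pattern:

```python
import torch
import torch.nn as nn

# Minimal sketch of projection-based graph reasoning (illustrative,
# not the paper's exact module): a learned assignment B maps the
# L = h * w pixels onto N nodes, nodes are updated, B^T maps back.
class TinyGraphReasoning(nn.Module):
    def __init__(self, channels=64, nodes=16):
        super().__init__()
        self.proj = nn.Conv2d(channels, nodes, kernel_size=1)    # predicts B
        self.node_update = nn.Conv1d(nodes, nodes, kernel_size=1)

    def forward(self, x):                       # x: (b, C, h, w)
        b, c, h, w = x.shape
        B = self.proj(x).view(b, -1, h * w)     # (b, N, L) pixel-to-node assignment
        V = torch.bmm(B, x.view(b, c, -1).transpose(1, 2))   # (b, N, C) node features
        V = self.node_update(V)                 # message passing among the N nodes
        y = torch.bmm(B.transpose(1, 2), V)     # (b, L, C) back to pixels
        return y.transpose(1, 2).view(b, c, h, w)

feat = torch.randn(2, 64, 16, 16)
print(TinyGraphReasoning()(feat).shape)         # torch.Size([2, 64, 16, 16])
```

PGR applies this reasoning independently to each patch and additionally learns a sigmoid gate over the p2² patches, as the full module code below shows.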
## Module Code
```python
import torch
import torch.nn as nn
class graph(nn.Module):
    def __init__(self, p2=4, nIn=64, N=16):
        super(graph, self).__init__()
        self.p2 = p2  # the feature map is split into p2 x p2 patches
        self.N = N    # number of graph nodes
        # Predicts the pixel-to-node assignment matrix B
        self.conv30 = nn.Sequential(
            nn.Conv2d(nIn, self.N, kernel_size=3, stride=1, padding=1, groups=1),
            nn.ReLU(inplace=True)
        )
        # Channel-wise update of the node features
        self.conv10 = nn.Sequential(
            nn.Conv1d(nIn, nIn, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True)
        )
        # Node-wise update (message passing among the N nodes)
        self.conv11 = nn.Sequential(
            nn.Conv1d(self.N, self.N, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True)
        )
        # Average-pool to a p2 x p2 grid (one value per patch) so the
        # view to p2 ** 2 in ADP_weight holds for any p2
        self.adaptiveavg = nn.AdaptiveAvgPool2d((p2, p2))
        # Bottleneck over the p2^2 patch weights, squeeze-and-excitation style
        self.conv12 = nn.Sequential(
            nn.Conv1d(p2 ** 2, p2, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv1d(p2, p2, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv1d(p2, p2 ** 2, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid()
        )
    def ADP_weight(self, x):
        # Derive per-patch, per-channel gates from a pooled global descriptor
        b, C, H, W = x.shape
        fg = self.adaptiveavg(x)                 # (b, C, p2, p2)
        fg1 = fg.view(b, C, self.p2 ** 2)        # (b, C, p2^2)
        fg1 = torch.transpose(fg1, 1, 2)         # (b, p2^2, C)
        fg2 = self.conv12(fg1)                   # (b, p2^2, C), values in (0, 1)
        fg3 = fg2.unsqueeze(-1).unsqueeze(-1)    # (b, p2^2, C, 1, 1)
        return fg3
    def graph_convolution(self, fs, x):
        # fs: (b * p2^2, C, h, w) batch of patches; x only supplies the shapes
        b, C, H, W = x.shape
        h, w = H // self.p2, W // self.p2
        L = h * w                                # pixels per patch
        B = self.conv30(fs)                      # (b * p2^2, N, h, w)
        B1 = B.view(-1, self.N, L)               # assignment: pixels -> nodes
        fs1 = fs.view(-1, C, L)
        fs1 = torch.transpose(fs1, 1, 2)         # (b * p2^2, L, C)
        fs2 = torch.bmm(B1, fs1)                 # node features (b * p2^2, N, C)
        fs3 = self.conv11(fs2)                   # update across nodes
        fs5 = self.conv10(torch.transpose(fs3, 1, 2))  # update across channels
        # Reproject the node features back onto the pixel grid via B^T
        fs6 = torch.bmm(torch.transpose(B1, 1, 2), torch.transpose(fs5, 1, 2))
        fs6 = torch.transpose(fs6, 1, 2)         # (b * p2^2, C, L)
        fs6 = fs6.view(b, self.p2 ** 2, C, h, w)
        return fs6
    def forward(self, fs, x):
        fs6 = self.graph_convolution(fs, x)      # per-patch graph reasoning
        weight = self.ADP_weight(x)              # per-patch channel gates
        out = weight * fs6
        return out
class PGR(nn.Module):
    def __init__(self, p2=4, nIn=32, nOut=32, add=True):
        super(PGR, self).__init__()
        self.p2 = p2
        self.N = nIn // 4
        self.add = add
        self.graph0 = graph(p2, nIn, self.N)
        self.conv31 = nn.Sequential(
            nn.Conv2d(nOut, nOut, kernel_size=1, stride=1),
            nn.BatchNorm2d(nOut),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        b, C, H, W = x.shape
        h, w = H // self.p2, W // self.p2
        # Cut the feature map into p2 x p2 non-overlapping patches,
        # each in its own row-major slot i * p2 + j
        fs = torch.zeros((b, self.p2 ** 2, C, h, w), device=x.device)
        for i in range(self.p2):
            for j in range(self.p2):
                fs[:, i * self.p2 + j] = x[:, :, i * h:(i + 1) * h, j * w:(j + 1) * w]
        fs = fs.view(b * self.p2 ** 2, C, h, w)
        fs6 = self.graph0(fs, x)                 # (b, p2^2, C, h, w)
        # Stitch the reasoned patches back into a full feature map
        out = torch.zeros_like(x)
        for i in range(self.p2):
            for j in range(self.p2):
                out[:, :, i * h:(i + 1) * h, j * w:(j + 1) * w] = fs6[:, i * self.p2 + j]
        out = self.conv31(out)
        if self.add:
            out = out + x                        # residual connection
        return out
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(1, 64, 44, 44, device=device)
    pgr = PGR(p2=8, nIn=64, nOut=64).to(device)
    out = pgr(x)
    print(out.shape)  # torch.Size([1, 64, 44, 44])
```
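
One caveat when reusing the module: the patching loops rely on integer division, so when p2 does not divide H and W (as in the 44 × 44 self-test above, where h = w = 44 // 8 = 5) the bottom and right borders are never assigned to any patch and only pass through the final 1 × 1 convolution and the residual branch. Choosing an input size divisible by p2 processes every pixel; a quick check, continuing from the definitions above:

```python
x = torch.randn(1, 64, 64, 64)       # 64 is divisible by p2 = 8
pgr = PGR(p2=8, nIn=64, nOut=64)
print(pgr(x).shape)                  # torch.Size([1, 64, 64, 64])
```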