当前位置: 首页 > article >正文

Pytorch注意力机制应用到具体网络方法(闭眼都会版)

文章目录

  • 以YoloV4-tiny为例
    • 要加入的注意力机制代码
    • 模型中插入注意力机制

以YoloV4-tiny为例

在这里插入图片描述
解释一下各个部分:

  • 最左边这部分为主干提取网络,功能为特征提取
  • 中间这边部分为FPN,功能是加强特征提取
  • 最后一部分为yolo head,功能为获得我们具体的一个预测结果

需要明白几个点:

  • 注意力机制模块是一个即插即用的模块,理论上是可以添加到任何一个特征图后面
  • 但是,不建议添加到主干部分(即最左边的那部分),主干部分所用的特征是我们后面处理所用的基础,故不建议添加到主干部分
  • 如果添加到主干部分,由于注意力机制模块 它的权值模块是随机初始化的,那主干部分的权值就被破坏了,最开始提取出来的特征就不好用了。
  • 故建议把注意力机制模块添加到主干以外的部分

本节把注意力机制添加到加强网络里面,即上图的中间部分。
添加注意力机制可以添加到上图标注的部分。

要加入的注意力机制代码

这一部分为要加入的注意力机制模块,文件名为attention.py

import torch
from torch import nn
# 通道注意力机制
class channel_attention(nn.Module):
    def __init__(self,channel,ration=16):   #因为要进行全连接,故需要传入通道数量,及缩放比例
        super(channel_attention,self).__init__()  #初始化
        #定义最大池化层
        self.max_pool = nn.AdaptiveMaxPool2d(1) #输出层的高和宽是1
        #定义平均池化
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.fc = nn.Sequential(
            #定义第一次全连接
            nn.Linear(channel,channel // ration ,False),
            nn.ReLU(),
            # 定义第二次全连接
            nn.Linear(channel//ration,channel,False)
        )
        #由于图中的通道注意力机制是连个全连接层相加之后再取sigmoid
        self.sigmoid=nn.Sigmoid()

    #前传部分
    def forward(self,x):
        b,c,h,w=x.size()
        #首先对输入进来的x先进行一个全局最大池化 在进行一个全局平均池化
        max_pool_out=self.max_pool(x).view([b,c])
        avg_pool_out=self.avg_pool(x).view([b,c])
        #然后对两次池化后的结果用共享的全连接层fc进行处理
        max_fc_out=self.fc(max_pool_out)
        avg_fc_out=self.fc(avg_pool_out)
        #最后将上面的两个结果进行相加
        out=max_fc_out + avg_fc_out
        out=self.sigmoid(out).view([b,c,1,1])
        #print(out)
        return out * x
# 空间注意力机制
class spacial_attention(nn.Module):
    def __init__(self,kernel_size=7):   #空间注意力没有通道数,故不用传入channel和ration
        #但是空间注意力会进行一次卷积,故我们需要关注卷积核大小,一般为3或7
        super(spacial_attention,self).__init__()  #初始化
        padding=7//2  #卷积核大小整除输入通道数
        self.conv=nn.Conv2d(2,1,kernel_size,1,padding,bias=False)
        #由图可知输入通道数是2,输出通道数为1,卷积核大小默认设置为7,步长为1,因为不需要压缩特征层阿高和宽

        #由于图中的通道注意力机制是连个全连接层相加之后再取sigmoid
        self.sigmoid=nn.Sigmoid()
    #空间注意力机制前传部分
    def forward(self,x):
        b,c,h,w=x.size()
        max_pool_out,_= torch.max(x,dim=1,keepdim=True)#需要把通道这一维度保留下来,故设置keepdim为True
        #对于pytorch来讲,它的通道是在第一维度,也就是batchsize后面的那个维度故定义dim为1
        mean_pool_out = torch.mean(x,dim = 1,keepdim=True)
        #对最大值和平均值进行一个堆叠
        pool_out = torch.cat([max_pool_out, mean_pool_out],dim=1)
        #对堆叠后的结果取一个卷积
        out=self.conv(pool_out)
        out=self.sigmoid(out)
        print(out)
        return out * x

#把空间注意力机制和通道注意力机制进行一个融合
class Cbam(nn.Module):
    def __init__(self,channel,ratio=16,kernel_size=7):
        super(Cbam,self).__init__()
        #调用已经定义好的2个注意力机制
        self.channel_attention=channel_attention(channel,ratio)
        self.spacial_attention = spacial_attention(kernel_size)
    #融合后机制的前传部分
    def forward(self,x):
        x=self.channel_attention(x)
        x=self.spacial_attention(x)
        return x

在模型文件(yolo.py)中,首行添加如下部分

from .attention import se_block,cbam_block,eca_block
attention_blocks=[se_block,cbam_block,eca_block]
为何要设置成上面的形式?
为了方便调用,到时候可以直接编写下面的代码调用具体的注意力机制模块
attention_blocks[0]

之后,需要找到yolo.py里面的模型主体部分,大概形式如下代码

class YoloBody(nn.Module):
	def __init__(self,anchors_mask,num_classes,phi=0)
	#在原来的代码上只是添加了phi,代表我们选用的注意力机制模块,默认情况下为0
		super(YoloBody, self).__init__()
	        self.backbone       = darknet53_tiny(None)
	
	        self.conv_for_P5    = BasicConv(512,256,1)
	        self.yolo_headP5    = yolo_head([512, len(anchors_mask[0]) * (5 + num_classes)],256)
	
	        self.upsample       = Upsample(256,128)
	        self.yolo_headP4    = yolo_head([256, len(anchors_mask[1]) * (5 + num_classes)],384)
	        #下面这部分为自己填写
	     	self.phi    = phi  #这个是自己添加的
	        if 1 <= self.phi and self.phi <= 3:
            self.feat1_att      = attention_block[self.phi - 1](256)  #通道数为256
            self.feat2_att      = attention_block[self.phi - 1](512)#通道数为512
            self.upsample_att   = attention_block[self.phi - 1](128)#通道数为128
            #通道数到底是多少看这个模型的前传部分的通道数为多少
    def forward(self, x):
		#---------------------------------------------------#
		#   生成CSPdarknet53_tiny的主干模型
		#   feat1的shape为26,26,256
		#   feat2的shape为13,13,512
		#---------------------------------------------------#
		feat1, feat2 = self.backbone(x)
		#下面代码为自己填写
		if 1 <= self.phi and self.phi <= 3:#如果满足条件就添加具体的注意力机制
		    feat1 = self.feat1_att(feat1)
		    feat2 = self.feat2_att(feat2)
		#下面代码模型自带
		# 13,13,512 -> 13,13,256
		P5 = self.conv_for_P5(feat2)
		# 13,13,256 -> 13,13,512 -> 13,13,255
		out0 = self.yolo_headP5(P5) 
		
		# 13,13,256 -> 13,13,128 -> 26,26,128
		P5_Upsample = self.upsample(P5)
		# 26,26,256 + 26,26,128 -> 26,26,384
		#上面代码模型自带,下面代码自己编写
		if 1 <= self.phi and self.phi <= 3:
		    P5_Upsample = self.upsample_att(P5_Upsample)
		 #下面代码模型自带
		P4 = torch.cat([P5_Upsample,feat1],axis=1)
		
		# 26,26,384 -> 26,26,256 -> 26,26,255
		out1 = self.yolo_headP4(P4)
		
		return out0, out1


模型中插入注意力机制


http://www.kler.cn/a/457428.html

相关文章:

  • PostgreSQL对称between比较运算
  • C++模板相关概念汇总
  • 实现单例模式的五种方式
  • springboot集成qq邮箱服务
  • springboot实战(19)(条件分页查询、PageHelper、MYBATIS动态SQL、mapper映射配置文件、自定义类封装分页查询数据集)
  • Redis的生态系统和社区支持
  • vue导入导出excel、设置单元格文字颜色、背景色、合并单元格(使用xlsx-js-style库)
  • R 语言科研绘图第 11 期 --- 柱状图-基础
  • Linux -- 从抢票逻辑理解线程互斥
  • 酷瓜云课堂(内网版)v1.1.8 发布,局域网在线学习平台方案
  • 关于新手学习React的一些忠告
  • Selenium+Java(21):Jenkins发送邮件报错Not sent to the following valid addresses解决方案
  • 最新版Chrome浏览器加载ActiveX控件技术——alWebPlugin中间件V2.0.28-迎春版发布
  • 程序员学习方针
  • HashMap
  • 如果用Bert模型训练,epochs不宜过大
  • 使用 uni-app 开发的微信小程序中,如何在从 B 页面回来时,重新拉取数据?
  • 【LC】3046. 分割数组
  • 计算机体系结构期末复习4:多处理器缓存一致性(cache一致性)
  • UE5 丧尸类杂兵的简单AI
  • 【Spring MVC】第一站:Spring MVC介绍配置基本原理
  • 人工智能之基于阿里云进行人脸特征检测部署
  • UnityURP 自定义PostProcess之深度图应用
  • Nginx的性能分析与调优简介
  • template<typename Func, typename = void> 在类模板中的应用
  • windows 上安装nginx , 启停脚本