YOLOv5改进——添加SimAM注意力机制
目录
一、SimAM注意力机制核心代码
二、修改common.py
三、修改yolo.py
四、建立yaml文件
五、验证
六、出现RuntimeError问题
七、CA注意力机制
一、SimAM注意力机制核心代码
在models文件夹下新建modules文件夹,在modules文件夹下新建一个py文件。这里为simam.py。复制以下代码到文件里面。
import torch
import torch.nn as nn
class SimAM(torch.nn.Module):
def __init__(self, channels = None, e_lambda = 1e-4):
super(SimAM, self).__init__()
self.activaton = nn.Sigmoid()
self.e_lambda = e_lambda
def __repr__(self):
s = self.__class__.__name__ + '('
s += ('lambda=%f)' % self.e_lambda)
return s
@staticmethod
def get_module_name():
return "simam"
def forward(self, x):
b, c, h, w = x.size()
n = w * h - 1
x_minus_mu_square = (x - x.mean(dim=[2,3], keepdim=True)).pow(2)
y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2,3], keepdim=True) / n + self.e_lambda)) + 0.5
return x * self.activaton(y)
注意:很多改进教程都是将代码直接复制到common.py文件,如果改进机制多了容易造成混乱。建议创建一个modules文件夹,将改进机制放里面方便管理。
二、修改common.py
在common.py文件中,在前面的部分添加以下代码,导入GhostV2.py的内容:
from models.modules.simam import *
三、修改yolo.py
在yolo.py文件中,在导入common模块的上面一行添加以下代码,导入GhostV2.py的内容:
from models.modules.simam import *
注意:这里位置不要搞错,不然可能会找不到导入的模块。
如下图所示: