如何提取神经网络中间层特征向量
方法一 钩子(Hook)函数
举个实例来讲解,假设搭建的神经网络结构为一个简单的CNN,代码见下:
class CNN(nn.Module):
    """Four-stage convolutional classifier for single-channel images.

    The size comments assume a 101x101 input. Each convolution maps a side
    length n to (n + 2p - f) / s + 1, and each 2x2 max-pool halves it
    (rounding down). ``linear`` hard-codes 3*3*128 input features, so the
    network only accepts inputs whose final feature map is 3x3.

    Attributes:
        cnn: the convolutional stack (Conv -> BatchNorm -> ReLU -> MaxPool,
            four times, stored flat so indices run cnn.0 ... cnn.15).
        linear: the classifier head; ``linear[0]`` outputs a 256-dim vector
            and ``linear[2]`` outputs the 4 class scores.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, kernel_size) for each stage.
        stage_specs = (
            (1, 16, 5),     # Layer1: 101 -> 97 -> 48
            (16, 32, 5),    # Layer2: 48 -> 44 -> 22
            (32, 64, 4),    # Layer3: 22 -> 19 -> 9
            (64, 128, 4),   # Layer4: 9 -> 6 -> 3
        )

        # Build a flat list so the Sequential keeps one index per layer,
        # which keeps the state_dict keys (cnn.0.weight, ...) unchanged.
        layers = []
        for c_in, c_out, k in stage_specs:
            layers.append(
                nn.Conv2d(in_channels=c_in, out_channels=c_out,
                          kernel_size=k, stride=1, padding=0)
            )
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU())
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        self.cnn = nn.Sequential(*layers)

        self.linear = nn.Sequential(
            nn.Linear(3 * 3 * 128, 256),
            nn.ReLU(),
            nn.Linear(256, 4),
        )

    def forward(self, x):
        """Return the 4 class scores for a batch ``x``.

        Args:
            x: input batch; the layer sizes above assume (batch, 1, 101, 101).

        Returns:
            Tensor of shape (batch, 4).
        """
        x = self.cnn(x)
        # Flatten everything except the batch dimension for the linear head.
        x = x.view(x.shape[0], -1)
        x = self.linear(x)
        return x
我们如果想要提取出全连接层中第一个线性层输出的特征向量,即linear中nn.Linear(3*3*128,256)输出的维度为256的特征向量,那么我们可以这样做:
# Define the hook. PyTorch calls a forward hook as hook(module, inputs, output)
# each time the module it is registered on finishes a forward pass.
fmap_black0 = {}


def forward_hook0(model, inp, outp):
    """Store the hooked layer's latest inputs and output in ``fmap_black0``.

    Args:
        model: the module the hook is attached to (here the Linear layer,
            not the whole network, despite the name).
        inp: the inputs that were passed to that module.
        outp: the module's output for this forward pass.
    """
    fmap_black0['input'] = inp
    fmap_black0['output'] = outp


# linear[0] is the first module inside ``linear``, i.e.
# nn.Linear(3*3*128, 256), so the captured feature vector has 256 dimensions.
model.linear[0].register_forward_hook(forward_hook0)

# NOTE: the hook only fires during a forward pass. Run the network on a batch
# (e.g. ``model(x)``) before reading fmap_black0; until then the 'output' key
# does not exist and the lookup below raises KeyError.
# list.append() returns None, so fetch the tensor first and append separately.
feature = fmap_black0["output"]
tensor_list_train.append(feature)
这样就可以将所需特征向量提取出来啦。
但是,该方法更适用于结构较简单的神经网络。如果是比较复杂的网络,比如主干网络中还嵌套了其他网络(如下代码,只是举个例子,可以看到网络内嵌套了Feature_Net、TransformerModel等网络),想提取嵌套网络中某一层的特征时,虽然也可以沿着属性路径逐级找到该子模块再注册钩子,但需要事先弄清每一级的名称,写起来比较繁琐,因此下面提供方法二。
class TFFusion_Net_double(nn.Module):
    """Two-input fusion network built from nested sub-networks.

    Both inputs go through the same ``Feature_Net``; the two feature vectors
    are concatenated, passed through ``TransformerModel``, reshaped into a
    16x16 image, fed to ``XCiT`` and finally classified by ``Dense``.

    NOTE(review): Feature_Net, TransformerModel and XCiT are defined outside
    this snippet. The shape comments below are the original author's (batch
    size 64) and are not verifiable from here.

    Args:
        feature_dim: forwarded to ``Feature_Net``.
        d_model: forwarded to ``TransformerModel``.
        img_size, patch_size, embed_dim, depth, num_heads: forwarded to
            ``XCiT``; ``embed_dim`` also sizes the first ``Dense`` layer.
        output_dim: number of outputs of the final ``Dense`` layer.
    """

    def __init__(self, feature_dim=128, d_model=256, img_size=16, patch_size=16,
                 embed_dim=128, depth=4, num_heads=4, output_dim=4):
        super().__init__()
        self.feature_model = Feature_Net(feature_dim=feature_dim)
        self.attention = TransformerModel(d_model=d_model)
        self.xcit_model = XCiT(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=4,
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            eta=1.0,
            tokens_norm=False,
        )
        self.Dense = nn.Sequential(
            nn.Linear(embed_dim * 4 * 4, 1024),
            nn.Mish(),
            nn.Linear(1024, 128),
            nn.Mish(),
            nn.Linear(128, output_dim),
        )

    def forward(self, x, y):
        """Fuse the two inputs and return the ``Dense`` head's output.

        Args:
            x: first input, passed to ``feature_model``.
            y: second input, passed to the same ``feature_model``.

        Returns:
            Output of ``Dense``; last dimension is ``output_dim``.
        """
        # The same feature extractor (shared weights) handles both inputs.
        x, y = self.feature_model(x), self.feature_model(y)  # (64,128)
        mix = torch.cat((x, y), dim=1)  # (64,256)
        mix = torch.unsqueeze(mix, dim=1)  # (64,1,256)
        mix = self.attention(mix)
        # Hard-coded 16x16: requires exactly 256 values per sample here.
        mix = mix.view(x.shape[0], 1, 16, 16)
        # Repeat the single channel three times without copying data.
        mix = mix.expand(-1, 3, -1, -1)
        mix = self.xcit_model(mix)
        # Keep only the first element of what XCiT returns
        # (presumably a tuple or sequence; confirm against XCiT).
        mix = mix[0]
        mix = mix.contiguous().view(mix.shape[0], -1)
        mix = self.Dense(mix)
        return mix
方法二 torchextractor
GitHub开源地址:https://github.com/antoinebrl/torchextractor (torchextractor: Feature extraction made simple)
Github中的讲解也很清楚,这里依然举一个实例来讲解:
import torchextractor as tx # PyTorch Intermediate Feature Extraction
if __name__ == '__main__':
    # Load the network and its trained weights.
    # NOTE(review): `device`, `inputs`, `CNN_Net`, `torch` and `scipy.io` are
    # expected to be defined/imported elsewhere; they are not shown here.
    model = CNN_Net(output_dim=4)
    model_weight = r'XXXXXXXX.pth'
    # map_location lets a checkpoint saved on one device load onto `device`
    # (e.g. GPU-trained weights on a CPU-only machine).
    model.load_state_dict(torch.load(model_weight, map_location=device))
    model.to(device)

    # Print every addressable submodule name of the network.
    print(tx.list_module_names(model))
    # Example output:
    # ['', 'cnn', 'cnn.0', 'cnn.1', 'cnn.2', 'cnn.3', 'cnn.4', 'cnn.5',
    #  'cnn.6', 'cnn.7', 'cnn.8', 'cnn.9', 'cnn.10', 'cnn.11',
    #  'fc', 'fc.0', 'fc.1', 'fc.2']

    # Capture the output of "fc.1", i.e. the module at index 1 inside `fc`
    # in the listing above. Module names are passed as a list, as in the
    # torchextractor README.
    model = tx.Extractor(model, ["fc.1"])
    model.eval()

    # `inputs` is one sample (batch) taken from your own dataset. It must be
    # on the same device as the model; no gradients are needed for extraction.
    with torch.no_grad():
        _, feature = model(inputs.float().to(device))
    feature = feature['fc.1']
    # Flatten per sample, then keep the first sample of the batch.
    feature = feature.view(feature.shape[0], -1)[0]
    feature = feature.cpu().detach().numpy()

    # Save the feature vector to a MATLAB .mat file.
    scipy.io.savemat('XXXXXX.mat', {'data': feature})
比较复杂,含有嵌套网络的结构也可以打印出来,如:
'''
['', 'feature_model', 'feature_model.cnn', 'feature_model.cnn.0', 'feature_model.cnn.1',
'feature_model.cnn.2', 'feature_model.cnn.3', 'feature_model.cnn.4', 'feature_model.cnn.5',
'feature_model.cnn.6', 'feature_model.cnn.7', 'feature_model.cnn.8', 'feature_model.cnn.9',
'feature_model.cnn.10', 'feature_model.cnn.11', 'feature_model.fc', 'timexer_model',
'timexer_model.self_attn', 'timexer_model.self_attn.self_attn',
'timexer_model.self_attn.self_attn.out_proj', 'timexer_model.self_attn.linear1',
'timexer_model.self_attn.dropout', 'timexer_model.self_attn.linear2',
'timexer_model.self_attn.norm1', 'timexer_model.self_attn.norm2',
'timexer_model.self_attn.dropout1', 'timexer_model.self_attn.dropout2',
'timexer_model.cross_attn', 'timexer_model.cross_attn.self_attn',
'timexer_model.cross_attn.self_attn.out_proj', 'timexer_model.cross_attn.linear1',
'timexer_model.cross_attn.dropout', 'timexer_model.cross_attn.linear2',
'timexer_model.cross_attn.norm1', 'timexer_model.cross_attn.norm2',
'timexer_model.cross_attn.dropout1', 'timexer_model.cross_attn.dropout2']
'''
需要提取哪一层的输出特征,直接将上述实例中的"fc.1"替换为 tx.list_module_names(model) 打印出的对应模块名(例如"feature_model.cnn.11")即可。