Using Forward Hooks to Capture and Save Intermediate Layer Outputs of a Neural Network (Detailed Example)
Code example:
import torch.nn as nn

# Activation dictionary, used to store the intermediate features from each forward pass
activation = {}

# Define the forward_hook function outside of upsample_v2
def forward_hook(name):
    def hook(module, input, output):
        activation[name] = output.detach()
    return hook

def upsample_v2(in_channels, out_channels, upscale, kernel_size=3):
    layers = []
    # Define mid channel stages (three reduction steps)
    mid_channels = [256, 128, 64]  # 512x32x32 -> 256x64x64 -> 128x128x128 -> 64x256x256 -> 2x256x256
    scale_factor_per_step = upscale ** (1 / 3)  # Scaling factor applied at each step
    current_in_channels = in_channels
    # Upsample and reduce channels in 3 steps
    for step, mid_channel in enumerate(mid_channels):
        # Conv layer to reduce the number of channels
        conv = nn.Conv2d(current_in_channels, mid_channel, kernel_size=kernel_size, padding=1, bias=False)
        nn.init.kaiming_normal_(conv.weight.data, nonlinearity='relu')
        layers.append(conv)
        # ReLU activation
        relu = nn.ReLU()
        layers.append(relu)
        # Upsampling layer
        up = nn.Upsample(scale_factor=scale_factor_per_step, mode='bilinear', align_corners=True)
        layers.append(up)
        # Hook the upsampling layer just appended, keyed by its step index
        layers[-1].register_forward_hook(forward_hook(f'step_{step}'))
        # Update current in_channels for the next step
        current_in_channels = mid_channel
    # Final conv projecting to the desired number of output channels
    conv = nn.Conv2d(current_in_channels, out_channels, kernel_size=kernel_size, padding=1, bias=False)
    nn.init.kaiming_normal_(conv.weight.data, nonlinearity='relu')
    layers.append(conv)
    return nn.Sequential(*layers)
def forward_hook(name):
    def hook(module, input, output):
        activation[name] = output.detach()
    return hook
forward_hook sets up the capture function. Inside it, module is the layer being hooked, input is a tuple holding that layer's inputs, and output is that layer's output; in practice we usually only need output. The outer forward_hook(name) is a factory that closes over name, so each hooked layer stores its output under its own key in the activation dictionary.
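To see the signature in isolation, here is a minimal, self-contained sketch; the single conv layer, the feats dictionary, and the 'demo' key are hypothetical, invented only for illustration:

import torch
import torch.nn as nn

feats = {}  # hypothetical store, analogous to activation above

def forward_hook(name):
    def hook(module, input, output):
        # module: the hooked nn.Module itself (here an nn.Conv2d)
        # input:  a tuple of the tensors passed into module's forward
        # output: the tensor that module's forward returned
        feats[name] = output.detach()
    return hook

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
handle = conv.register_forward_hook(forward_hook('demo'))

x = torch.randn(1, 3, 32, 32)
_ = conv(x)                 # the hook fires during this forward pass
print(feats['demo'].shape)  # torch.Size([1, 8, 32, 32])
handle.remove()             # remove the hook once it is no longer needed

Note that register_forward_hook returns a handle; calling handle.remove() detaches the hook, which matters when you do not want stale features accumulating across later forward passes.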
layers[-1].register_forward_hook(forward_hook(f'step_{step}'))
This line specifies the layer I want to capture: layers[-1] is the layer most recently appended to layers, i.e. the upsampling layer added in the current iteration. Since the loop runs three times, three layers end up hooked, so after a forward pass the activation dictionary should hold three intermediate outputs, under the keys 'step_0', 'step_1', and 'step_2'.
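As a quick check, an end-to-end run might look like the sketch below. The input shape 1x512x32x32, upscale=8, and out_channels=2 are assumptions read off the shape comment in the code, and the printed sizes assume the per-step factor 8 ** (1/3) evaluates to exactly 2.0 (the filename passed to torch.save is likewise just illustrative):

import torch

model = upsample_v2(in_channels=512, out_channels=2, upscale=8)
x = torch.randn(1, 512, 32, 32)
y = model(x)  # running the model triggers the three hooks

print(y.shape)  # expected: torch.Size([1, 2, 256, 256])
for name, feat in activation.items():
    print(name, feat.shape)
# expected:
# step_0 torch.Size([1, 256, 64, 64])
# step_1 torch.Size([1, 128, 128, 128])
# step_2 torch.Size([1, 64, 256, 256])

# Save the captured features to disk (illustrative path)
torch.save(activation, 'activation.pt')

Because each hook stores output.detach(), the saved tensors carry no autograd history, so they can be moved to CPU or serialized without keeping the computation graph alive.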