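Key mismatch when loading pretrained weights
Scenario
While reproducing the paper "Rethinking the Learning Paradigm for Dynamic Facial Expression Recognition", I loaded an already-trained .pth checkpoint for inference and found that the accuracy was very low. The weights were loaded with the following two lines: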
weights_dict = torch.load('/data2/liuxu/attribute/M3DFEL/outputs/DFEW-[10-29]-[14:29]/model_best.pth', map_location='cuda:7')
mymodel.load_state_dict(weights_dict, strict=False)  # strict=False is meant for weight dicts whose keys only partly differ from the model's; mismatched keys are skipped instead of raising an error
The call returned the following report:
_IncompatibleKeys(missing_keys=['features.0.0.weight', 'features.0.1.weight', 'features.0.1.bias', 'features.0.1.running_mean', 'features.0.1.running_var', 'features.1.0.conv1.0.weight', 'features.1.0.conv1.1.weight', 'features.1.0.conv1.1.bias', 'features.1.0.conv1.1.running_mean', 'features.1.0.conv1.1.running_var', 'features.1.0.conv2.0.weight', 'features.1.0.conv2.1.weight', 'features.1.0.conv2.1.bias', 'features.1.0.conv2.1.running_mean', 'features.1.0.conv2.1.running_var', 'features.1.1.conv1.0.weight', 'features.1.1.conv1.1.weight', 'features.1.1.conv1.1.bias', 'features.1.1.conv1.1.running_mean', 'features.1.1.conv1.1.running_var', 'features.1.1.conv2.0.weight', 'features.1.1.conv2.1.weight', 'features.1.1.conv2.1.bias', 'features.1.1.conv2.1.running_mean', 'features.1.1.conv2.1.running_var', 'features.2.0.conv1.0.weight', 'features.2.0.conv1.1.weight', 'features.2.0.conv1.1.bias', 'features.2.0.conv1.1.running_mean', 'features.2.0.conv1.1.running_var', 'features.2.0.conv2.0.weight', 'features.2.0.conv2.1.weight', 'features.2.0.conv2.1.bias', 'features.2.0.conv2.1.running_mean', 'features.2.0.conv2.1.running_var', 'features.2.0.downsample.0.weight', 'features.2.0.downsample.1.weight', 'features.2.0.downsample.1.bias', 'features.2.0.downsample.1.running_mean', 'features.2.0.downsample.1.running_var', 'features.2.1.conv1.0.weight', 'features.2.1.conv1.1.weight', 'features.2.1.conv1.1.bias', 'features.2.1.conv1.1.running_mean', 'features.2.1.conv1.1.running_var', 'features.2.1.conv2.0.weight', 'features.2.1.conv2.1.weight', 'features.2.1.conv2.1.bias', 'features.2.1.conv2.1.running_mean', 'features.2.1.conv2.1.running_var', 'features.3.0.conv1.0.weight', 'features.3.0.conv1.1.weight', 'features.3.0.conv1.1.bias', 'features.3.0.conv1.1.running_mean', 'features.3.0.conv1.1.running_var', 'features.3.0.conv2.0.weight', 'features.3.0.conv2.1.weight', 'features.3.0.conv2.1.bias', 'features.3.0.conv2.1.running_mean', 'features.3.0.conv2.1.running_var',
'features.3.0.downsample.0.weight', 'features.3.0.downsample.1.weight', 'features.3.0.downsample.1.bias', 'features.3.0.downsample.1.running_mean', 'features.3.0.downsample.1.running_var', 'features.3.1.conv1.0.weight', 'features.3.1.conv1.1.weight', 'features.3.1.conv1.1.bias', 'features.3.1.conv1.1.running_mean', 'features.3.1.conv1.1.running_var', 'features.3.1.conv2.0.weight', 'features.3.1.conv2.1.weight', 'features.3.1.conv2.1.bias', 'features.3.1.conv2.1.running_mean', 'features.3.1.conv2.1.running_var', 'features.4.0.conv1.0.weight', 'features.4.0.conv1.1.weight', 'features.4.0.conv1.1.bias', 'features.4.0.conv1.1.running_mean', 'features.4.0.conv1.1.running_var', 'features.4.0.conv2.0.weight', 'features.4.0.conv2.1.weight', 'features.4.0.conv2.1.bias', 'features.4.0.conv2.1.running_mean', 'features.4.0.conv2.1.running_var', 'features.4.0.downsample.0.weight', 'features.4.0.downsample.1.weight', 'features.4.0.downsample.1.bias', 'features.4.0.downsample.1.running_mean', 'features.4.0.downsample.1.running_var', 'features.4.1.conv1.0.weight', 'features.4.1.conv1.1.weight', 'features.4.1.conv1.1.bias', 'features.4.1.conv1.1.running_mean', 'features.4.1.conv1.1.running_var', 'features.4.1.conv2.0.weight', 'features.4.1.conv2.1.weight', 'features.4.1.conv2.1.bias', 'features.4.1.conv2.1.running_mean', 'features.4.1.conv2.1.running_var', 'lstm.weight_ih_l0', 'lstm.weight_hh_l0', 'lstm.bias_ih_l0', 'lstm.bias_hh_l0', 'lstm.weight_ih_l0_reverse', 'lstm.weight_hh_l0_reverse', 'lstm.bias_ih_l0_reverse', 'lstm.bias_hh_l0_reverse', 'lstm.weight_ih_l1', 'lstm.weight_hh_l1', 'lstm.bias_ih_l1', 'lstm.bias_hh_l1', 'lstm.weight_ih_l1_reverse', 'lstm.weight_hh_l1_reverse', 'lstm.bias_ih_l1_reverse', 'lstm.bias_hh_l1_reverse', 'to_qkv.weight', 'norm.weight', 'norm.bias', 'norm.mean_weight', 'norm.var_weight', 'pwconv.weight', 'pwconv.bias', 'fc.weight', 'fc.bias'], unexpected_keys=['epoch', 'state_dict', 'best_wa', 'best_ua', 'optimizer', 'args'])
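Because strict=False suppresses the error, a report like the one above is easy to overlook. Here is a minimal sketch of a guard that turns the report into a hard failure. assert_fully_loaded is a hypothetical helper, not part of PyTorch; it reads only the missing_keys and unexpected_keys fields that appear in the report above, and the Report namedtuple is a stand-in so the sketch runs without a model.

```python
from collections import namedtuple


def assert_fully_loaded(result):
    """Raise if a load_state_dict() report lists missing or unexpected keys.

    `result` is the object returned by `model.load_state_dict(...)`; only its
    `missing_keys` and `unexpected_keys` attributes are read.
    """
    missing = list(result.missing_keys)
    unexpected = list(result.unexpected_keys)
    if missing or unexpected:
        raise RuntimeError(
            f"{len(missing)} missing key(s), e.g. {missing[:3]}; "
            f"{len(unexpected)} unexpected key(s), e.g. {unexpected[:3]}"
        )


# Stand-in for the real return value (assumption: same two field names).
Report = namedtuple("Report", ["missing_keys", "unexpected_keys"])

try:
    assert_fully_loaded(Report(["fc.weight", "fc.bias"], ["epoch", "state_dict"]))
except RuntimeError as err:
    print("load failed:", err)
```

With a real model the call would look like assert_fully_loaded(mymodel.load_state_dict(weights_dict, strict=False)). Leaving strict at its default of True is the simpler option when every key is expected to match, since load_state_dict then raises a RuntimeError by itself.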
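Root cause and solution
Inspecting the keys of the loaded dictionary shows that every top-level key is one the model does not expect. With strict=False they were all skipped, so no weights were loaded and the model kept its initial parameters, which explains the low accuracy.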
weights_dict.keys()
dict_keys(['epoch', 'state_dict', 'best_wa', 'best_ua', 'optimizer', 'args'])
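Looking further, the keys that the model reported as missing are all stored under 'state_dict'. Indexing the checkpoint with weights_dict["state_dict"] extracts the dictionary whose keys match the model. The modified code is: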
weights_dict = torch.load('/data2/liuxu/attribute/M3DFEL/outputs/DFEW-[10-29]-[14:29]/model_best.pth', map_location='cuda:7')
mymodel.load_state_dict(weights_dict["state_dict"], strict=False)
<All keys matched successfully>
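The unwrapping step can be made reusable so the same loading code handles both a bare parameter dict and a checkpoint that wraps it in training metadata. The sketch below uses a hypothetical helper, extract_state_dict, and plain dictionaries with placeholder values (not taken from the real file) so it runs without PyTorch; the top-level key names mirror the ones printed above.

```python
def extract_state_dict(checkpoint, key="state_dict"):
    """Return the parameter dict from a checkpoint that may wrap it in metadata.

    If `checkpoint` is a dict holding another dict under `key`, that inner
    dict is returned; otherwise `checkpoint` is returned unchanged.
    """
    if isinstance(checkpoint, dict) and isinstance(checkpoint.get(key), dict):
        return checkpoint[key]
    return checkpoint


# Wrapped checkpoint: metadata next to the weights (placeholder values).
wrapped = {
    "epoch": 0,
    "state_dict": {"fc.weight": "placeholder", "fc.bias": "placeholder"},
    "optimizer": {},
}
# Bare checkpoint: the parameter dict itself.
bare = {"fc.weight": "placeholder", "fc.bias": "placeholder"}

print(sorted(extract_state_dict(wrapped)))  # ['fc.bias', 'fc.weight']
print(sorted(extract_state_dict(bare)))     # ['fc.bias', 'fc.weight']
```

In the real script this would be used as mymodel.load_state_dict(extract_state_dict(weights_dict)). The isinstance check also accepts an OrderedDict, since it is a subclass of dict.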
With the weights actually loaded, the pretrained model now runs inference properly:
device = "cuda:7"
mymodel.to(device)
test_dataloader = create_dataloader(args, "test")
mymodel.eval()  # switch BatchNorm/Dropout layers to inference behavior
all_pred, all_target = [], []
# Model inference (first test batch only)
for i, (images, target) in enumerate(test_dataloader):
    images = images.to(device)
    target = target.to(device)
    with torch.no_grad():
        output = mymodel(images)
    # Predicted class = index of the largest logit for each sample
    pred = torch.argmax(output, 1).cpu().detach().numpy()
    target = target.cpu().numpy()
    print(pred, target)
    break
[1 2 1 2 3 1 2 2 4 4 3 6 1 2 3 6] [1 2 1 3 3 1 6 4 4 2 3 6 0 4 3 0]
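To put a number on that batch, the two printed arrays can be compared directly. The sketch below copies them as plain lists and computes the overall accuracy and the mean per-class recall. Treating these two quantities as what the checkpoint's best_wa and best_ua fields track is an assumption on my part, and a single batch of 16 samples is far too small to judge the model.

```python
from collections import defaultdict

# Predictions and labels copied from the printed batch above.
pred = [1, 2, 1, 2, 3, 1, 2, 2, 4, 4, 3, 6, 1, 2, 3, 6]
target = [1, 2, 1, 3, 3, 1, 6, 4, 4, 2, 3, 6, 0, 4, 3, 0]


def overall_accuracy(pred, target):
    """Fraction of samples whose predicted class equals the label."""
    correct = sum(p == t for p, t in zip(pred, target))
    return correct / len(target)


def mean_class_recall(pred, target):
    """Average of per-class recall over the classes present in `target`."""
    hits, totals = defaultdict(int), defaultdict(int)
    for p, t in zip(pred, target):
        totals[t] += 1
        hits[t] += int(p == t)
    return sum(hits[c] / totals[c] for c in totals) / len(totals)


print(f"overall accuracy:  {overall_accuracy(pred, target):.4f}")   # 0.5625
print(f"mean class recall: {mean_class_recall(pred, target):.4f}")  # 0.5139
```

Here 9 of the 16 predictions match their labels. Accumulating pred and target over the whole test loader (for example in the all_pred and all_target lists) before calling the two functions would give the full-test-set figures.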