
Key Mismatch When Loading Pretrained Weights

Scenario

While reproducing the paper "Rethinking the Learning Paradigm for Dynamic Facial Expression Recognition", I loaded an already-trained .pth checkpoint for inference and found the accuracy was very low. The weights were loaded with these two lines:

weights_dict = torch.load('/data2/liuxu/attribute/M3DFEL/outputs/DFEW-[10-29]-[14:29]/model_best.pth', map_location='cuda:7')
mymodel.load_state_dict(weights_dict, strict=False)  # strict=False tolerates key mismatches between the checkpoint and the model, so no exception is raised even when nothing matches

The call returned the following:

_IncompatibleKeys(missing_keys=['features.0.0.weight', 'features.0.1.weight', 'features.0.1.bias', 'features.0.1.running_mean', 'features.0.1.running_var', 'features.1.0.conv1.0.weight', 'features.1.0.conv1.1.weight', 'features.1.0.conv1.1.bias', 'features.1.0.conv1.1.running_mean', 'features.1.0.conv1.1.running_var', 'features.1.0.conv2.0.weight', 'features.1.0.conv2.1.weight', 'features.1.0.conv2.1.bias', 'features.1.0.conv2.1.running_mean', 'features.1.0.conv2.1.running_var', 'features.1.1.conv1.0.weight', 'features.1.1.conv1.1.weight', 'features.1.1.conv1.1.bias', 'features.1.1.conv1.1.running_mean', 'features.1.1.conv1.1.running_var', 'features.1.1.conv2.0.weight', 'features.1.1.conv2.1.weight', 'features.1.1.conv2.1.bias', 'features.1.1.conv2.1.running_mean', 'features.1.1.conv2.1.running_var', 'features.2.0.conv1.0.weight', 'features.2.0.conv1.1.weight', 'features.2.0.conv1.1.bias', 'features.2.0.conv1.1.running_mean', 'features.2.0.conv1.1.running_var', 'features.2.0.conv2.0.weight', 'features.2.0.conv2.1.weight', 'features.2.0.conv2.1.bias', 'features.2.0.conv2.1.running_mean', 'features.2.0.conv2.1.running_var', 'features.2.0.downsample.0.weight', 'features.2.0.downsample.1.weight', 'features.2.0.downsample.1.bias', 'features.2.0.downsample.1.running_mean', 'features.2.0.downsample.1.running_var', 'features.2.1.conv1.0.weight', 'features.2.1.conv1.1.weight', 'features.2.1.conv1.1.bias', 'features.2.1.conv1.1.running_mean', 'features.2.1.conv1.1.running_var', 'features.2.1.conv2.0.weight', 'features.2.1.conv2.1.weight', 'features.2.1.conv2.1.bias', 'features.2.1.conv2.1.running_mean', 'features.2.1.conv2.1.running_var', 'features.3.0.conv1.0.weight', 'features.3.0.conv1.1.weight', 'features.3.0.conv1.1.bias', 'features.3.0.conv1.1.running_mean', 'features.3.0.conv1.1.running_var', 'features.3.0.conv2.0.weight', 'features.3.0.conv2.1.weight', 'features.3.0.conv2.1.bias', 'features.3.0.conv2.1.running_mean', 'features.3.0.conv2.1.running_var', 'features.3.0.downsample.0.weight', 'features.3.0.downsample.1.weight', 'features.3.0.downsample.1.bias', 'features.3.0.downsample.1.running_mean', 'features.3.0.downsample.1.running_var', 'features.3.1.conv1.0.weight', 'features.3.1.conv1.1.weight', 'features.3.1.conv1.1.bias', 'features.3.1.conv1.1.running_mean', 'features.3.1.conv1.1.running_var', 'features.3.1.conv2.0.weight', 'features.3.1.conv2.1.weight', 'features.3.1.conv2.1.bias', 'features.3.1.conv2.1.running_mean', 'features.3.1.conv2.1.running_var', 'features.4.0.conv1.0.weight', 'features.4.0.conv1.1.weight', 'features.4.0.conv1.1.bias', 'features.4.0.conv1.1.running_mean', 'features.4.0.conv1.1.running_var', 'features.4.0.conv2.0.weight', 'features.4.0.conv2.1.weight', 'features.4.0.conv2.1.bias', 'features.4.0.conv2.1.running_mean', 'features.4.0.conv2.1.running_var', 'features.4.0.downsample.0.weight', 'features.4.0.downsample.1.weight', 'features.4.0.downsample.1.bias', 'features.4.0.downsample.1.running_mean', 'features.4.0.downsample.1.running_var', 'features.4.1.conv1.0.weight', 'features.4.1.conv1.1.weight', 'features.4.1.conv1.1.bias', 'features.4.1.conv1.1.running_mean', 'features.4.1.conv1.1.running_var', 'features.4.1.conv2.0.weight', 'features.4.1.conv2.1.weight', 'features.4.1.conv2.1.bias', 'features.4.1.conv2.1.running_mean', 'features.4.1.conv2.1.running_var', 'lstm.weight_ih_l0', 'lstm.weight_hh_l0', 'lstm.bias_ih_l0', 'lstm.bias_hh_l0', 'lstm.weight_ih_l0_reverse', 'lstm.weight_hh_l0_reverse', 'lstm.bias_ih_l0_reverse', 'lstm.bias_hh_l0_reverse', 'lstm.weight_ih_l1', 'lstm.weight_hh_l1', 'lstm.bias_ih_l1', 'lstm.bias_hh_l1', 'lstm.weight_ih_l1_reverse', 'lstm.weight_hh_l1_reverse', 'lstm.bias_ih_l1_reverse', 'lstm.bias_hh_l1_reverse', 'to_qkv.weight', 'norm.weight', 'norm.bias', 'norm.mean_weight', 'norm.var_weight', 'pwconv.weight', 'pwconv.bias', 'fc.weight', 'fc.bias'], unexpected_keys=['epoch', 'state_dict', 'best_wa', 'best_ua', 'optimizer', 'args'])
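Note what strict=False actually does: it raises no exception, leaves every missing parameter at its random initialization, and merely reports the mismatch in the returned _IncompatibleKeys tuple. That is why inference still ran, just with near-random accuracy. A minimal sketch of this behavior with a toy model (the layer names and wrapper fields here are illustrative, not the paper's):

```python
import torch
import torch.nn as nn

# Toy model standing in for the real network (layer names are illustrative)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# A checkpoint wrapped the way training scripts commonly save it
checkpoint = {"epoch": 10, "state_dict": model.state_dict(), "best_wa": 0.0}

# Loading the wrapper directly: every real parameter counts as "missing",
# every wrapper field as "unexpected", yet strict=False raises no error
result = model.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)     # ['0.weight', '0.bias', '2.weight', '2.bias']
print(result.unexpected_keys)  # ['epoch', 'state_dict', 'best_wa']
```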

Root Cause and Fix

Inspecting the keys of the loaded checkpoint shows that, from the model's point of view, none of them are parameter keys at all:

weights_dict.keys()

dict_keys(['epoch', 'state_dict', 'best_wa', 'best_ua', 'optimizer', 'args'])

Digging further, all of the model's "missing" keys actually live under the 'state_dict' entry. Extracting weights_dict["state_dict"] yields the sub-dictionary that matches the model. The corrected code:

weights_dict = torch.load('/data2/liuxu/attribute/M3DFEL/outputs/DFEW-[10-29]-[14:29]/model_best.pth', map_location='cuda:7')
mymodel.load_state_dict(weights_dict["state_dict"], strict=False)

<All keys matched successfully>

With that change, the pretrained model performs inference correctly.
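As a general pattern, training checkpoints often wrap the weights under a 'state_dict' entry, sometimes with an additional 'module.' prefix on every key added by nn.DataParallel. A small helper can unwrap both cases before loading (the function name is ours, a sketch rather than anything from the paper's code):

```python
def extract_state_dict(checkpoint):
    """Return bare model weights from a checkpoint that may wrap them
    under a 'state_dict' key and may carry a 'module.' prefix."""
    state = checkpoint.get("state_dict", checkpoint) if isinstance(checkpoint, dict) else checkpoint
    # Strip the 'module.' prefix that nn.DataParallel prepends to every key
    return {k[len("module."):] if k.startswith("module.") else k: v
            for k, v in state.items()}
```

With the unwrapped dict it is also safer to load with the default strict=True, so any genuine architecture mismatch fails loudly instead of silently: `mymodel.load_state_dict(extract_state_dict(weights_dict))`.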

device = "cuda:7"
mymodel.to(device)

test_dataloader = create_dataloader(args, "test")

mymodel.eval()

all_pred, all_target = [], []
# Model inference (first batch only, as a quick sanity check)
for i, (images, target) in enumerate(test_dataloader):
    images = images.to(device)
    target = target.to(device)
    with torch.no_grad():
        output = mymodel(images)
    pred = torch.argmax(output, 1).cpu().detach().numpy()
    target = target.cpu().numpy()
    print(pred,target)
    break

[1 2 1 2 3 1 2 2 4 4 3 6 1 2 3 6] [1 2 1 3 3 1 6 4 4 2 3 6 0 4 3 0]
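For a real evaluation, the break should be removed and predictions accumulated across batches into all_pred and all_target; eyeballing one batch only confirms the model is no longer random. A small helper (our own, not from the repo) turns the accumulated per-batch arrays into an overall accuracy, demonstrated on the single batch printed above:

```python
import numpy as np

def accuracy_from_batches(pred_batches, target_batches):
    """Overall accuracy from lists of per-batch prediction/label arrays."""
    pred = np.concatenate([np.asarray(b) for b in pred_batches])
    target = np.concatenate([np.asarray(b) for b in target_batches])
    return float((pred == target).mean())

# The batch printed above: 9 of the 16 predictions match the labels
print(accuracy_from_batches(
    [[1, 2, 1, 2, 3, 1, 2, 2, 4, 4, 3, 6, 1, 2, 3, 6]],
    [[1, 2, 1, 3, 3, 1, 6, 4, 4, 2, 3, 6, 0, 4, 3, 0]],
))  # -> 0.5625
```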

