当前位置: 首页 > article >正文

pytorch——保存‘类别名与类别数量’到权值文件中

前言

不知道大家有没有像我一样,每换一次不一样的模型,就要输入不同的num_classes和name_classes,反正我是很头疼诶,尤其是项目里面不止一个模型的时候,更新的时候看着就很头疼,然后就想着直接输入模型权值文件的path该多好,然后我就搞起来了。

在自己的类中加入想要加入数据信息

class your_nets(nn.Module):
    def __init__(self, num_classes = 21,name_classes=None):
        super(your_nets, self).__init__()
        self.num_classes = num_classes
        self.name_classes = name_classes

训练过程之保存文件

      
model = your_nets(num_classes=num_classes, name_classes=name_classes)

save_dict = {
                'state_dict': model.state_dict(),
                'num_classes': model.num_classes,
                'name_classes': model.name_classes
            }

torch.save(save_dict, os.path.join(save_dir, "best_epoch_weights.pth"))

使用 

model = get_nets_class(model_path=model_path)


class get_nets_class(object):
    def __init__(self ,**kwargs):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        load_dict  = torch.load(self.model_path, map_location=device)

        state_dict =load_dict['state_dict']
        num_classes = load_dict['num_classes']
        name_classes = load_dict['name_classes']

        if num_classes is not None and name_classes is not None:
            self.num_classes =num_classes
            self.name_classes = name_classes
            self.net = your_nets(num_classes=self.num_classes,name_classes=name_classe)
            self.net.load_state_dict(state_dict)
        else:
            self.net = your_nets(num_classes=self.num_classes, backbone=self.backbone)
            self.net.load_state_dict(load_dict)
        self.net = self.net.eval()
    
    def predict(self,image,name_classes,object_list):
        #你的预处理操作,没有就忽略
        image_data = preprocess(image)
        with torch.no_grad():
            # 推理
            pr = self.net(images)[0]
            # softmax 得出概率 pr.permute(1, 2, 0), dim=-1为我自己的操作,没有请忽略
            pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()
        #你的后处理操作,没有就忽略
        pr = postprocess(pr)
        #这一步与object_list有关 object_list是你想要模型去预测的内容
        # 例如你训练了识别cat、dog、pig、person的类别 那么你想只识别人,那么就object_list=['person'] 
        if object_list is not None:
            model_object_list = [name_classes.index(i) for i in object_list if i in name_classes]
            temp_list = [i for i in range(len(name_classes))]
            remove_list = [i for i in temp_list if i not in model_object_list]
            for i in remove_list:
                pr[pr==i] = 0
        retuen pr

我是觉得已经很详细了,大家要是不懂可以再问,我可以慢慢改进,每个人的写法都不一样 。

欢迎大家点赞加收藏哟~


http://www.kler.cn/news/232874.html

相关文章:

  • python创建udf函数步骤
  • macbook电脑如何永久删除app软件?
  • java基础(2) 面向对象编程-java核心类
  • pytest+allure批量执行测试用例
  • Linux操作系统基础(三):虚拟机与Linux系统安装
  • MATLAB环境下用于提取冲击信号的几种解卷积方法
  • 致我的2023年——个人学年总结
  • 32I2C通信协议
  • android 音频调试技巧
  • 25、数据结构/二叉树相关练习20240207
  • vue项目开发vscode配置
  • 《学成在线》微服务实战项目实操笔记系列(P1~P83)【上】
  • FastAPI使用ORJSONResponse作为默认的响应类型
  • MyBatis之动态代理实现增删改查以及MyBatis-config.xml中读取DB信息文件和SQL中JavaBean别名配置
  • 极值图论基础
  • VScode为什么选择了Electron,而不是QT?
  • Leecode之环形链表
  • c#进程(Process)常用方法
  • Linux运用fork函数创建进程
  • Ubuntu22.04 gnome-builder gnome C 应用程序习练笔记(一)
  • 教你用C++开发 身份证号码日期提取工具
  • 除夕快乐(前端小烟花)
  • 【C++ 二分】电脑游戏
  • 聊聊JIT优化技术
  • Android9~Android13 某些容量SD卡被格式化为内部存储时容量显示错误问题的研究与解决方案
  • 贪心算法入门题(算法村第十七关青铜挑战)
  • Get Ready!这些 ALVA 应用即将上线 Vision Pro!
  • C语言:分支与循环
  • nodejs+vue高校实验室耗材管理系统_m20vy
  • 探索XGBoost:参数调优与模型解释