当前位置: 首页 > article >正文

机器学习(模型的保存和加载)

在机器学习中,模型训练通常需要耗费大量的时间和计算资源。为了避免重复训练,同时方便在不同环境中使用已经训练好的模型,我们需要对模型进行保存和加载。以下将介绍几种常见的模型保存与加载的方法,以 scikit - learn 和 TensorFlow 模型为例。

1. 使用 joblib 保存和加载 scikit - learn 模型

joblib 是 scikit - learn 推荐的用于保存和加载模型的工具,它在处理大型 numpy 数组时比 Python 内置的 pickle 模块更高效。

保存模型 

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from joblib import dump

# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 训练模型
model = KNeighborsClassifier()
model.fit(X_train, y_train)

# 保存模型
dump(model, 'knn_model.joblib')

 加载模型

from joblib import load

# 加载模型
loaded_model = load('knn_model.joblib')

# 使用加载的模型进行预测
predictions = loaded_model.predict(X_test)
print(predictions)

2. 使用 pickle 保存和加载 scikit - learn 模型

pickle 是 Python 内置的用于对象序列化和反序列化的模块,也可以用于保存和加载 scikit - learn 模型。

保存模型

import pickle
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 训练模型
model = KNeighborsClassifier()
model.fit(X_train, y_train)

# 保存模型
with open('knn_model.pkl', 'wb') as f:
    pickle.dump(model, f)

 加载模型

import pickle

# 加载模型
with open('knn_model.pkl', 'rb') as f:
    loaded_model = pickle.load(f)

# 使用加载的模型进行预测
predictions = loaded_model.predict(X_test)
print(predictions)

3. 使用 TensorFlow 保存和加载深度学习模型

保存模型

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten

# 加载数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 数据预处理
x_train = x_train / 255.0
x_test = x_test / 255.0

# 构建模型
model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])

# 编译模型
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 训练模型
model.fit(x_train, y_train, epochs=5)

# 保存模型
model.save('mnist_model.h5')

加载模型

import tensorflow as tf
from tensorflow.keras.datasets import mnist

# 加载数据集
(_, _), (x_test, y_test) = mnist.load_data()

# 数据预处理
x_test = x_test / 255.0

# 加载模型
loaded_model = tf.keras.models.load_model('mnist_model.h5')

# 使用加载的模型进行预测
predictions = loaded_model.predict(x_test)
print(predictions)

总结

  • 对于 scikit - learn 模型,推荐使用 joblib 进行保存和加载,尤其是处理大型 numpy 数组时。
  • pickle 是 Python 内置的通用序列化工具,也可以用于保存和加载 scikit - learn 模型。
  • 对于 TensorFlow 深度学习模型,可以使用 model.save() 方法保存为 .h5 文件,使用 tf.keras.models.load_model() 方法加载模型

http://www.kler.cn/a/565363.html

相关文章:

  • 【版本控制安全简报】Perforce Helix Core安全更新:漏洞修复与国内用户支持
  • nginx 动态计算拦截非法访问ip
  • 【Linux】ubuntu server扩容硬盘
  • 树莓百度百科更新!宜宾园区业务再添新篇
  • 【Python爬虫(96)】从0到1:打造爬虫驱动的数据分析平台
  • 联想 SR590 服务器 530-8i RAID 控制器更换损坏的硬盘
  • 以太坊测试网
  • YOLOv8+QT搭建目标检测项目
  • ruoyi vue el-elementui el-tree 自适应宽度向左浮动
  • 从扫描到建模:盎锐UCL360PRO如何实现隧道的数字化重建运维
  • MinIO整合SpringBoot实现文件上传、下载
  • 服务器广播需要广播的服务器数量
  • DOM 事件 HTML 标签属性速查手册
  • 【无监督学习】主成分分析步骤及matlab实现
  • MySQL数据库入门到大蛇尚硅谷宋红康老师笔记 高级篇 part 4
  • 神经网络 - 激活函数(ReLU 函数)
  • 第12章_管理令牌和会话
  • AORO M6北斗短报文终端:将“太空黑科技”转化为安全保障
  • 提示学习(Prompting)
  • 使用ssh客户端完成Linux远程登录