机器学习(模型的保存和加载)
在机器学习中,模型训练通常需要耗费大量的时间和计算资源。为了避免重复训练,同时方便在不同环境中使用已经训练好的模型,我们需要对模型进行保存和加载。以下将介绍几种常见的模型保存与加载的方法,以 scikit - learn
和 TensorFlow
模型为例。
1. 使用 joblib
保存和加载 scikit - learn
模型
joblib
是 scikit - learn
推荐的用于保存和加载模型的工具,它在处理大型 numpy
数组时比 Python 内置的 pickle
模块更高效。
保存模型
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from joblib import dump
# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 训练模型
model = KNeighborsClassifier()
model.fit(X_train, y_train)
# 保存模型
dump(model, 'knn_model.joblib')
加载模型
from joblib import load
# 加载模型
loaded_model = load('knn_model.joblib')
# 使用加载的模型进行预测
predictions = loaded_model.predict(X_test)
print(predictions)
2. 使用 pickle
保存和加载 scikit - learn
模型
pickle
是 Python 内置的用于对象序列化和反序列化的模块,也可以用于保存和加载 scikit - learn
模型。
保存模型
import pickle
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 训练模型
model = KNeighborsClassifier()
model.fit(X_train, y_train)
# 保存模型
with open('knn_model.pkl', 'wb') as f:
pickle.dump(model, f)
加载模型
import pickle
# 加载模型
with open('knn_model.pkl', 'rb') as f:
loaded_model = pickle.load(f)
# 使用加载的模型进行预测
predictions = loaded_model.predict(X_test)
print(predictions)
3. 使用 TensorFlow
保存和加载深度学习模型
保存模型
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
# 加载数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 数据预处理
x_train = x_train / 255.0
x_test = x_test / 255.0
# 构建模型
model = Sequential([
Flatten(input_shape=(28, 28)),
Dense(128, activation='relu'),
Dense(10, activation='softmax')
])
# 编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 训练模型
model.fit(x_train, y_train, epochs=5)
# 保存模型
model.save('mnist_model.h5')
加载模型
import tensorflow as tf
from tensorflow.keras.datasets import mnist
# 加载数据集
(_, _), (x_test, y_test) = mnist.load_data()
# 数据预处理
x_test = x_test / 255.0
# 加载模型
loaded_model = tf.keras.models.load_model('mnist_model.h5')
# 使用加载的模型进行预测
predictions = loaded_model.predict(x_test)
print(predictions)
总结
- 对于
scikit - learn
模型,推荐使用joblib
进行保存和加载,尤其是处理大型numpy
数组时。 pickle
是 Python 内置的通用序列化工具,也可以用于保存和加载scikit - learn
模型。- 对于
TensorFlow
深度学习模型,可以使用model.save()
方法保存为.h5
文件,使用tf.keras.models.load_model()
方法加载模型