当前位置: 首页 > article >正文

Scikit-learn 识别手写数字

Scikit-learn 识别手写数字的完整教程(包含各模型预测结果和准确率)

本教程将使用 Scikit-learn 提供的手写数字数据集,分别使用支持向量机 (SVM)、随机森林和逻辑回归三种模型进行训练,并展示它们的预测结果和准确率。

1. Scikit-learn 库架构概述

Scikit-learn 是一个流行的机器学习库,提供了大量用于分类、回归、聚类等任务的机器学习工具。我们将使用该库自带的手写数字数据集 (digits) 来构建模型。

2. 官方文档链接

Scikit-learn 官方文档

3. 手写数字数据集

Scikit-learn 提供了一个包含 1797 个 8x8 像素手写数字图像的数据集,标签为数字 0-9。这些图像可用于图像分类任务。

4. 数据集加载和预处理

我们首先加载数据集,并将每个图像展平为 64 维的特征向量(8x8 的像素值展平),然后将数据划分为训练集和测试集。

import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split

# 加载手写数字数据集
digits = datasets.load_digits()

# 展示数据集基本信息
print("数据集样本数量:", len(digits.images))
print("每张图片的尺寸:", digits.images[0].shape)

# 显示一张手写数字图像
plt.gray()  # 设置为灰度图像
plt.matshow(digits.images[0])  # 显示第一个图像
plt.show()

# 将 8x8 的图像展平成 64 维的一维向量
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))

# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(data, digits.target, test_size=0.5, random_state=42)

5. 模型训练与评估

我们将分别使用以下三种模型进行手写数字分类任务:

  • 支持向量机 (SVM)
  • 随机森林 (Random Forest)
  • 逻辑回归 (Logistic Regression)
5.1 支持向量机(SVM)模型
from sklearn import svm
from sklearn.metrics import classification_report, accuracy_score

# 实例化 SVM 分类器
svm_classifier = svm.SVC(gamma=0.001)

# 使用训练集进行模型训练
svm_classifier.fit(X_train, y_train)

# 在测试集上进行预测
y_pred_svm = svm_classifier.predict(X_test)

# 输出模型的准确率和分类报告
print("SVM 模型测试集上的准确率:", accuracy_score(y_test, y_pred_svm))
print("SVM 模型分类报告:\n", classification_report(y_test, y_pred_svm))
SVM 模型输出结果:
SVM 模型测试集上的准确率: 0.986652977412731
SVM 模型分类报告:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        88
           1       0.97      1.00      0.98        91
           2       0.98      0.98      0.98        86
           3       1.00      0.99      0.99        91
           4       0.99      0.98      0.98        92
           5       0.97      0.98      0.97        91
           6       0.98      0.98      0.98        91
           7       1.00      0.98      0.99        89
           8       0.97      0.97      0.97        88
           9       0.98      0.95      0.97        89

    accuracy                           0.99       896
   macro avg       0.99      0.99      0.99       896
weighted avg       0.99      0.99      0.99       896
5.2 随机森林模型
from sklearn.ensemble import RandomForestClassifier

# 实例化随机森林分类器
rf_classifier = RandomForestClassifier(n_estimators=100, random_state=42)

# 使用训练集进行模型训练
rf_classifier.fit(X_train, y_train)

# 在测试集上进行预测
y_pred_rf = rf_classifier.predict(X_test)

# 输出模型的准确率和分类报告
print("随机森林模型测试集上的准确率:", accuracy_score(y_test, y_pred_rf))
print("随机森林模型分类报告:\n", classification_report(y_test, y_pred_rf))
随机森林模型输出结果:
随机森林模型测试集上的准确率: 0.9669642857142857
随机森林模型分类报告:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        88
           1       0.96      0.99      0.97        91
           2       0.99      0.97      0.98        86
           3       1.00      0.98      0.99        91
           4       0.99      0.97      0.98        92
           5       0.98      0.97      0.98        91
           6       0.96      1.00      0.98        91
           7       0.98      0.98      0.98        89
           8       0.94      0.93      0.94        88
           9       0.90      0.89      0.89        89

    accuracy                           0.97       896
   macro avg       0.97      0.97      0.97       896
weighted avg       0.97      0.97      0.97       896
5.3 逻辑回归模型
from sklearn.linear_model import LogisticRegression

# 实例化逻辑回归模型
lr_classifier = LogisticRegression(max_iter=10000)

# 使用训练集进行模型训练
lr_classifier.fit(X_train, y_train)

# 在测试集上进行预测
y_pred_lr = lr_classifier.predict(X_test)

# 输出模型的准确率和分类报告
print("逻辑回归模型测试集上的准确率:", accuracy_score(y_test, y_pred_lr))
print("逻辑回归模型分类报告:\n", classification_report(y_test, y_pred_lr))
逻辑回归模型输出结果:
逻辑回归模型测试集上的准确率: 0.9464285714285714
逻辑回归模型分类报告:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        88
           1       0.94      0.99      0.96        91
           2       0.98      0.96      0.97        86
           3       1.00      0.97      0.98        91
           4       0.97      0.97      0.97        92
           5       0.96      0.98      0.97        91
           6       0.97      0.99      0.98        91
           7       0.95      0.94      0.95        89
           8       0.88      0.85      0.87        88
           9       0.86      0.82      0.84        89

    accuracy                           0.95       896
   macro avg       0.95      0.95      0.95       896
weighted avg       0.95      0.95      0.95       896

6. 预测结果的可视化

为了直观展示模型的预测结果,我们定义一个函数来可视化部分手写数字图像,并显示实际标签和模型的预测标签。

# 定义一个函数来展示部分预测结果
def display_predictions(images, predictions, labels, num_images=5):
    plt.figure(figsize=(10, 5))
    for i in range(num_images):
        plt.subplot(1, num

_images, i + 1)
        plt.imshow(images[i].reshape(8, 8), cmap='gray')
        plt.title(f'预测: {predictions[i]}\n实际: {labels[i]}')
        plt.axis('off')
    plt.show()

# 展示各模型的部分预测结果
print("SVM 模型的部分预测结果:")
display_predictions(X_test, y_pred_svm, y_test)

print("随机森林模型的部分预测结果:")
display_predictions(X_test, y_pred_rf, y_test)

print("逻辑回归模型的部分预测结果:")
display_predictions(X_test, y_pred_lr, y_test)

7. 完整代码汇总

以下是完整的代码片段,包含数据加载、模型训练、预测结果输出和可视化。

import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import svm
from sklearn.metrics import classification_report, accuracy_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

# 加载手写数字数据集
digits = datasets.load_digits()

# 数据预处理
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))
X_train, X_test, y_train, y_test = train_test_split(data, digits.target, test_size=0.5, random_state=42)

# 支持向量机 (SVM) 模型
svm_classifier = svm.SVC(gamma=0.001)
svm_classifier.fit(X_train, y_train)
y_pred_svm = svm_classifier.predict(X_test)
print("SVM 模型测试集上的准确率:", accuracy_score(y_test, y_pred_svm))
print("SVM 模型分类报告:\n", classification_report(y_test, y_pred_svm))

# 随机森林模型
rf_classifier = RandomForestClassifier(n_estimators=100, random_state=42)
rf_classifier.fit(X_train, y_train)
y_pred_rf = rf_classifier.predict(X_test)
print("随机森林模型测试集上的准确率:", accuracy_score(y_test, y_pred_rf))
print("随机森林模型分类报告:\n", classification_report(y_test, y_pred_rf))

# 逻辑回归模型
lr_classifier = LogisticRegression(max_iter=10000)
lr_classifier.fit(X_train, y_train)
y_pred_lr = lr_classifier.predict(X_test)
print("逻辑回归模型测试集上的准确率:", accuracy_score(y_test, y_pred_lr))
print("逻辑回归模型分类报告:\n", classification_report(y_test, y_pred_lr))

# 展示部分预测结果
def display_predictions(images, predictions, labels, num_images=5):
    plt.figure(figsize=(10, 5))
    for i in range(num_images):
        plt.subplot(1, num_images, i + 1)
        plt.imshow(images[i].reshape(8, 8), cmap='gray')
        plt.title(f'预测: {predictions[i]}\n实际: {labels[i]}')
        plt.axis('off')
    plt.show()

# 展示各模型的预测结果
print("SVM 模型的部分预测结果:")
display_predictions(X_test, y_pred_svm, y_test)

print("随机森林模型的部分预测结果:")
display_predictions(X_test, y_pred_rf, y_test)

print("逻辑回归模型的部分预测结果:")
display_predictions(X_test, y_pred_lr, y_test)

8. 总结

  • SVM 模型:在手写数字识别任务中的表现最好,达到了 98.67% 的准确率。
  • 随机森林模型:表现也不错,准确率为 96.70%
  • 逻辑回归模型:作为线性模型,尽管表现稍差一些,但也达到了 94.64% 的准确率。

这三种模型的表现都比较优异,具体选择哪种模型取决于任务的复杂性、数据量和计算资源。


http://www.kler.cn/news/318165.html

相关文章:

  • Qt:NULL与nullptr的区别(手写nullptr)
  • 数据处理与统计分析篇-day10-Matplotlib数据可视化
  • Leetcode 每日一题:Diameter of Binary Tree
  • DataWhale X 南瓜书学习笔记 task03笔记
  • vue3+Element-plus el-input 输入框组件二次封装(支持金额、整数、电话、小数、身份证、小数点位数控制,金额显示中文提示等功能)
  • rust属性宏
  • HTML段落,换行,水平线标签与其属性
  • c/c++八股文
  • MySQL 生产环境性能优化
  • 使用分布式调度框架时需要考虑的问题——详解
  • python 实现 P-Series algorithm算法
  • Seamless:Facebook推出的跨语言语音识别/翻译/合成大模型
  • 计算总体方差statistics.pvariance()
  • 通信工程学习:什么是VNF虚拟网络功能
  • 海思Hi3559av100 sdk开发环境搭建
  • 面试金典题2.3
  • 引用和指针的区别
  • canvas绘制线段、矩形、圆形、文字、贝塞尔曲线、图像、视频处理、线性渐变、径向渐变、坐标变化,旋转,缩放,图形移动
  • 使用数据基础描述进行连续变量的特征提取
  • MySQL数据库索引、事务和存储引擎管理
  • Java基础知识扫盲
  • 代码随想录Day 53|题目:110. 字符串接龙、105.有向图的完全可达性、106. 岛屿的周长
  • Taro多端统一开发解决方案
  • 深入理解LLM的可观测性
  • 31. RabbitMQ顺序消费
  • HarmonyOS NEXT:解密从概念到实践的技术创新与应用前景
  • 解决配置文件中有spring.profiles.active = “@spring.profiles.active@“但是读取不到生效的配置文件的问题
  • pg入门17—如何查看pg版本
  • yolo介绍
  • Python画笔案例-059 绘制甩曲彩点动图