当前位置: 首页 > article >正文

NCU-机器学习-作业3:基于SVM的手写字识别

任务描述:

手写数字识别是生活中尤其常见的机器学习任务,给出一份手写数字训练数据集,训练一个SVM模型并对测试集进行手写数字识别。

输入数据:

在train/目录下包含多个txt文件,其中每个文件表示一个用01矩阵表示的手写数字,文件名中下划线前面的数字代表手写数字的值(如2_167.txt表示手写数字为2;3_13.txt表示手写数字为3,训练数据集可在教学资料中下载,文件名为svm_train.tar)。

在test/目录下也包含多个txt文件,只不过test文件夹下面的txt文件无法从文件名得知手写数字的值(文件名:0.txt~945.txt),需要根据训练好的模型进行预测。

输出数据:

程序需要生成一个result.csv文件,用于保存程序对test中各个txt文件中手写数字值的预测结果。第一行固定为num,之后每一行为一个数值,代表预测值,表示程序对test中对应txt文件的预测结果。

评价标准:

测试集上的准确率。

输入样例:

00000000000000000011110000000000
00000000000000001111111100000000
00000000001000111111111100000000
00000000011111111111111110000000
00000000111111111111111110000000
00000000111111111111111110000000
00000000111111111111111110000000
00000000111111111111111111000000
00000001111111111101111111000000
00000000111111000000001111000000
00000001111110000000011111000000
00000001111100000000011111000000
00000001111100000000011111000000
00000001111100000000001111000000
00000001111100000000001111000000
00000001111100000000001111000000
00000001111100000000001111000000
00000001111100000000001111000000
00000001111100000000001111000000
00000001111100000000001111000000
00000001111100000000011111000000
00000000111100000000011111000000
00000000011110000000011111000000
00000000111100000001111110000000
00000000111110000111111000000000
00000000111111111111111000000000
00000000011111111111111000000000
00000000011111111111111000000000
00000000011111111111110000000000
00000000001111111111110000000000
00000000000111111111000000000000
00000000000000111100000000000000

输出样例:

num
0
1
2
3
4

思路代码:

Tips:仅为样例代码,存在可优化部分。

import os

import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
import numpy as np


def get_dataset(path, need_label=True):
    dataset, labels = [], []
    filenames = os.listdir(path)

    for filename in filenames:
        if need_label:
            labels.append(filename[0])
        filepath = os.path.join(path, filename)
        dataset.append(np.fromfile(filepath, dtype=np.uint8))

    if need_label:
        return dataset, labels
    return dataset


if __name__ == '__main__':
    X_train, y_train = get_dataset("train")
    X_test = get_dataset("test", need_label=False)

    # 数据标准化
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)  # 使用同一个scaler的transform,避免误差
    y_train = list(y_train)

    model = SVC()
    model.fit(X_train, y_train)

    y_pred_test = model.predict(X_test)

    # 保存预测结果到result.csv
    results = pd.DataFrame({'num': y_pred_test})
    results.to_csv('result.csv', index=False)

答案提交:

提交.py文件即可。


http://www.kler.cn/a/329383.html

相关文章:

  • 学习记录之原型,原型链
  • 无人机高速无刷动力电机核心设计技术
  • 网络安全:信息时代的守护者
  • 农业农村大数据应用场景|珈和科技“数字乡村一张图”解决方案
  • 将IDLE里面python环境pyqt5配置的vscode
  • 数据结构与算法之查找: LeetCode 69. x 的平方根 (Ts版)
  • linux ip命令使用
  • 大数据毕业设计选题推荐-热门微博数据可视化分析系统-Hive-Hadoop-Spark
  • C动态内存管理
  • 【在Linux世界中追寻伟大的One Piece】System V共享内存
  • Spring DI 笔记
  • 使用rust写一个Web服务器——单线程版本
  • 基于SSM+VUE的学生宿舍管理系统
  • 单链表的增删改查(数据结构)
  • OpenAI o1:使用限额提高,o1 模型解析
  • 基于STM32的智能家居语音控制系统:集成LD3320、ESP8266设计流程
  • 【优选算法】(第八篇)
  • 【已解决】【Hadoop】【./bin的使用】bash: ./bin/hdfs: 没有那个文件或目录
  • 基于 Transformer 的中英文翻译项目
  • .NET CORE程序发布IIS后报错误 500.19
  • 网络通信——OSPF协议(基础篇)
  • c++primer第十三章 类继承
  • 第一弹:C++ 的基本知识概述
  • 【深海王国】初中生也能画的电路板?目录合集
  • 巡检机器人室内配电室应用
  • web - RequestResponse