当前位置: 首页 > article >正文

HIT 模式识别 手写汉字分类 Python实现

训练集数据 TrainSamples-400.csv,含 100 个不同汉字,每个汉字 400 个实例,每个实例均为 64*64 的二值图像;
训练集标注TrainSamples-400.csv,为 40000 个 0 到 99 间的整数,表示训练集中每个实例所属汉字类别;
测试集数据 TestSamples-300.csv,为 30000 个实例,每个实例格式同训练集。
要求标注测试集,输出 Result.csv。

import numpy as np
import pandas as pd
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import models, layers

def train():
    data = pd.read_csv("TrainSamples-400.csv", header=None)
    train_image = data.to_numpy()
    data = pd.read_csv("TrainLabels-400.csv", header=None)
    train_label = data.to_numpy()
    train_label = to_categorical(train_label)
    network = models.Sequential()
    network.add(layers.Input(shape = (64, 64, 1)))
    network.add(layers.Conv2D(64, (5, 5), activation = 'relu'))
    network.add(layers.MaxPooling2D((2, 2)))
    network.add(layers.Conv2D(96, (3, 3), activation = 'relu'))
    network.add(layers.MaxPooling2D((2, 2)))
    network.add(layers.Conv2D(48, (3, 3), activation = 'relu'))
    network.add(layers.Flatten())
    network.add(layers.Dense(768, activation = 'relu'))
    network.add(layers.Dense(100, activation = 'softmax'))
    network.summary()
    network.compile(optimizer = 'rmsprop', loss = 'categorical_crossentropy', metrics = ['accuracy'])
    network.fit(train_image.reshape(40000, 64, 64, 1), train_label, epochs = 5, batch_size = 64, validation_split = 0.1, validation_freq = 1)
    network.save('saved_model/my_model')
    
def test():
    data = pd.read_csv("TestSamples-300.csv", header = None)
    test_image = data.to_numpy()
    network = models.load_model('saved_model/my_model')
    network.summary()
    test_label = network.predict(test_image.reshape(30000, 64, 64, 1))
    test_label = np.array([np.argmax(i) for i in test_label])
    pd.DataFrame(test_label).to_csv('Result.csv', header = None, index = False)

if __name__ == '__main__':
    train()
    test()

http://www.kler.cn/a/133648.html

相关文章:

  • 从对等通信到万维网:通信模型变迁与拥塞求解
  • 重生之我在异世界学编程之C语言:深入指针篇(上)
  • C++priority_queue模拟实现
  • OpenVela 各模块之间的交互方式和数据流
  • 纯前端实现表格中的数据导出功能-使用xlsx和file-saver
  • flume系列之:flume落cos
  • 038、语义分割
  • C++--哈希表--散列--冲突--哈希闭散列模拟实现
  • LintCode 1394 · Goat Latin (字符串处理题)
  • NET8 ORM 使用AOT SqlSugar
  • RabbitMQ-高级篇-黑马程序员
  • jsp中使用PDF.js实现pdf文件的预览
  • php mysql 如何处理查询中存在正则特殊字符的查询
  • Element-UI el-select下拉框多选实现全选
  • 生命在于学习——主板跳线的学习
  • OPPO发布AndesGPT大模型;Emu Video和Emu Edit的新突破
  • 正则表达式去掉代码末尾的数字
  • Hive默认分割符、存储格式与数据压缩
  • Linux环境的Windows子系统
  • C语言基础篇4:变量、存储、库函数
  • 【Seata源码学习 】篇三 seata客户端全局事务开启、提交与回滚
  • 【hive-解决】HiveAccessControlException Permission denied: CREATEFUNCTION
  • Linux 网络:PMTUD 简介
  • 麦克风阵列入门
  • Linux - 物理内存管理 - memmap
  • java游戏制作-拼图游戏