HIT 模式识别 手写汉字分类 Python实现
训练集数据 TrainSamples-400.csv,含 100 个不同汉字,每个汉字 400 个实例,每个实例均为 64*64 的二值图像;
训练集标注 TrainLabels-400.csv,为 40000 个 0 到 99 间的整数,表示训练集中每个实例所属汉字类别;
测试集数据 TestSamples-300.csv,为 30000 个实例,每个实例格式同训练集。
要求标注测试集,输出 Result.csv。
import numpy as np
import pandas as pd
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import models, layers
def train():
    """Train the handwritten-character CNN and save it to disk.

    Reads one flattened 64x64 image per row from ``TrainSamples-400.csv``
    and the matching integer class ids from ``TrainLabels-400.csv``,
    one-hot encodes the labels, fits a small convolutional network for
    5 epochs with 10% of the samples held out for validation, and writes
    the trained model to ``saved_model/my_model``.
    """
    image_side = 64
    num_classes = 100

    train_image = pd.read_csv("TrainSamples-400.csv", header=None).to_numpy()
    train_label = pd.read_csv("TrainLabels-400.csv", header=None).to_numpy()
    # Pin the one-hot width to the size of the softmax layer. Letting
    # to_categorical infer it from the data would produce too few columns
    # if the highest class id happened to be missing from the label file.
    train_label = to_categorical(train_label, num_classes=num_classes)

    network = models.Sequential()
    network.add(layers.Input(shape=(image_side, image_side, 1)))
    network.add(layers.Conv2D(64, (5, 5), activation='relu'))
    network.add(layers.MaxPooling2D((2, 2)))
    network.add(layers.Conv2D(96, (3, 3), activation='relu'))
    network.add(layers.MaxPooling2D((2, 2)))
    network.add(layers.Conv2D(48, (3, 3), activation='relu'))
    network.add(layers.Flatten())
    network.add(layers.Dense(768, activation='relu'))
    network.add(layers.Dense(num_classes, activation='softmax'))
    network.summary()

    network.compile(
        optimizer='rmsprop',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    network.fit(
        # -1 lets NumPy infer the sample count from the CSV instead of
        # hard-coding 40000, so a different-sized training set still works.
        train_image.reshape(-1, image_side, image_side, 1),
        train_label,
        epochs=5,
        batch_size=64,
        validation_split=0.1,
        validation_freq=1,
    )
    network.save('saved_model/my_model')
def test():
    """Label the test set with the saved model and write ``Result.csv``.

    Reads one flattened 64x64 image per row from ``TestSamples-300.csv``,
    loads the model stored at ``saved_model/my_model``, predicts class
    scores for every image and writes the index of the highest-scoring
    class for each image to ``Result.csv`` (one value per line, no header
    and no index column).
    """
    test_image = pd.read_csv("TestSamples-300.csv", header=None).to_numpy()

    network = models.load_model('saved_model/my_model')
    network.summary()

    # -1 lets NumPy infer the sample count from the CSV instead of
    # hard-coding 30000.
    scores = network.predict(test_image.reshape(-1, 64, 64, 1))
    # A single vectorised argmax over the class axis replaces the per-row
    # Python loop and yields the same integer labels.
    test_label = np.argmax(scores, axis=1)

    pd.DataFrame(test_label).to_csv('Result.csv', header=False, index=False)
if __name__ == '__main__':
    # Script entry point: run the pipeline stages in order -- fit and save
    # the model first, then use the saved model to label the test set.
    for stage in (train, test):
        stage()