当前位置: 首页 > article >正文

神经网络_使用tensorflow对fashion mnist衣服数据集分类

from tensorflow import keras 
import matplotlib.pyplot as plt

1.数据预处理

1.1 下载数据集

fashion_mnist = keras.datasets.fashion_mnist
#下载 fashion mnist数据集
(train_images, train_labels),(test_images, test_labels) = fashion_mnist.load_data()

print("train_images shape ", train_images.shape)
print("train_labels shape ", train_labels.shape)
print("train_labels[0] ", train_labels[0])
train_images shape  (60000, 28, 28)
train_labels shape  (60000,)
train_labels[0]  9

1.2展示数据集的第一张图片

plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.grid(False)
plt.show
<function matplotlib.pyplot.show(close=None, block=None)>

在这里插入图片描述

1.3 展示前25张图片和图片名称

train_images = train_images / 255.0;
test_images = test_images / 255.0;

plt.figure(figsize=(10, 10))
class_names = ['T-shirt/top','Trouser','Pullover','Dress','Coat', 
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
print("train_labels ", train_labels[:25])
for i in range(25):
    plt.subplot(5, 5, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])
plt.show()
    
train_labels  [9 0 0 3 0 2 7 2 5 5 0 9 5 5 7 9 1 0 6 4 3 1 4 8 4]

在这里插入图片描述

2. 模型实现

2.1模型定义

#定义模型
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28,28)),
    keras.layers.Dense(128, activation="relu"),
    keras.layers.Dense(10, activation="softmax")
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=10)

D:\python\Lib\site-packages\keras\src\layers\reshaping\flatten.py:37: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.
  super().__init__(**kwargs)


Epoch 1/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 977us/step - accuracy: 0.0967 - loss: 2.3028
Epoch 2/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1ms/step - accuracy: 0.0991 - loss: 2.3027
Epoch 3/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1ms/step - accuracy: 0.0956 - loss: 2.3028
Epoch 4/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1ms/step - accuracy: 0.0987 - loss: 2.3027
Epoch 5/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 968us/step - accuracy: 0.0988 - loss: 2.3028
Epoch 6/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1ms/step - accuracy: 0.1009 - loss: 2.3027
Epoch 7/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1ms/step - accuracy: 0.0998 - loss: 2.3027
Epoch 8/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1ms/step - accuracy: 0.0968 - loss: 2.3028
Epoch 9/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1ms/step - accuracy: 0.1036 - loss: 2.3027
Epoch 10/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 987us/step - accuracy: 0.0973 - loss: 2.3028





<keras.src.callbacks.history.History at 0x20049c207d0>

2.2模型评估测试

#评估测试
test_loss, test_accuracy = model.evaluate(test_images, test_labels, verbose=2)
print("test_loss ", test_loss)
print("test_accuracy", test_accuracy)
313/313 - 0s - 892us/step - accuracy: 0.1000 - loss: 2.3026
test_loss  2.3026490211486816
test_accuracy 0.10000000149011612

2.3模型预测

predict_result = model.predict(test_images)
print("predict_result shape, 样本数,每个样本对每个分类的得分 ", predict_result.shape)
print("样本1的每个分类得分, ", predict_result[0])
sample_one_result = np.argmax(predict_result[0])
print("样本1的分类结果%d %s"%(sample_one_result,class_names[sample_one_result]))
print("样本1的真实分类结果%d %s"%(test_labels[0],class_names[test_labels[0]]))
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 671us/step
predict_result shape, 样本数,每个样本对每个分类的得分  (10000, 10)
样本1的每个分类得分,  [0.10038214 0.09719477 0.10009037 0.10101561 0.09946147 0.10165851
 0.10063848 0.09979857 0.09982409 0.09993599]
样本1的分类结果5 Sandal
样本1的真实分类结果9 Ankle boot

2.4 查看指定测试图片的预测结果

#画指定索引位置的图
def plot_image(index, predict_classes, true_labels, images):
    true_label = true_labels[index]
    image = images[index]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(image, cmap=plt.cm.binary)
    predict_label = np.argmax(predict_classes)
    if predict_label == true_label:
        color = 'blue'
    else:
        color = 'red'
    plt.xlabel("{} {:2.0f}%({})".format(
                class_names[predict_label],
                100 * np.max(predict_classes),
                class_names[true_label]
                ), color=color)
# 画指定样本的对所有分类的预测得分    
def plot_predict_classes(i, predict_classes, true_labels):
    true_label = train_labels[i]
    plt.grid(False)
    plt.xticks(range(10))
    plt.yticks([])
    current_plot = plt.bar(range(10), predict_classes, color="#777777")
    plt.ylim([0,1])
    predict_label = np.argmax(predict_classes)
    current_plot[predict_label].set_color("red")
    current_plot[true_label].set_color('blue')

 # 画第一个样本的图,和对每个分类的得分
i = 0
plt.figure(figsize=(6,3))
plt.subplot(1,2,1)
plot_image(i, predict_result[i], test_labels, test_images)
plt.subplot(1,2,2)
plot_predict_classes(i, predict_result[i], test_labels)
plt.show()

在这里插入图片描述

3.保存训练的模型

3.1保存模型

# 保存模型
model.save('fashion_model.h5')
WARNING:absl:You are saving your model as an HDF5 file via `model.save()` or `keras.saving.save_model(model)`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')` or `keras.saving.save_model(model, 'my_model.keras')`. 

3.2保存模型到json文件

#查看模型
model_json = model.to_json()
print("model json: ", model_json)

#保存json到文件中
with open('fashion_model_config.json', 'w') as json:
    json.write(model_json)

#从json文件中加载模型
print("json from model")
json_model = keras.models.model_from_json(model_json)
json_model.summary()
model json:  {"module": "keras", "class_name": "Sequential", "config": {"name": "sequential", "trainable": true, "dtype": {"module": "keras", "class_name": "DTypePolicy", "config": {"name": "float32"}, "registered_name": null}, "layers": [{"module": "keras.layers", "class_name": "InputLayer", "config": {"batch_shape": [null, 28, 28], "dtype": "float32", "sparse": false, "name": "input_layer"}, "registered_name": null}, {"module": "keras.layers", "class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": {"module": "keras", "class_name": "DTypePolicy", "config": {"name": "float32"}, "registered_name": null}, "data_format": "channels_last"}, "registered_name": null, "build_config": {"input_shape": [null, 28, 28]}}, {"module": "keras.layers", "class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": {"module": "keras", "class_name": "DTypePolicy", "config": {"name": "float32"}, "registered_name": null}, "units": 128, "activation": "relu", "use_bias": true, "kernel_initializer": {"module": "keras.initializers", "class_name": "GlorotUniform", "config": {"seed": null}, "registered_name": null}, "bias_initializer": {"module": "keras.initializers", "class_name": "Zeros", "config": {}, "registered_name": null}, "kernel_regularizer": null, "bias_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "registered_name": null, "build_config": {"input_shape": [null, 784]}}, {"module": "keras.layers", "class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": {"module": "keras", "class_name": "DTypePolicy", "config": {"name": "float32"}, "registered_name": null}, "units": 10, "activation": "softmax", "use_bias": true, "kernel_initializer": {"module": "keras.initializers", "class_name": "GlorotUniform", "config": {"seed": null}, "registered_name": null}, "bias_initializer": {"module": "keras.initializers", "class_name": "Zeros", "config": {}, "registered_name": null}, "kernel_regularizer": null, "bias_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "registered_name": null, "build_config": {"input_shape": [null, 128]}}], "build_input_shape": [null, 28, 28]}, "registered_name": null, "build_config": {"input_shape": [null, 28, 28]}, "compile_config": {"loss": "sparse_categorical_crossentropy", "loss_weights": null, "metrics": ["accuracy"], "weighted_metrics": null, "run_eagerly": false, "steps_per_execution": 1, "jit_compile": false}}
json from model
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                         ┃ Output Shape                ┃         Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ flatten (Flatten)                    │ (None, 784)                 │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dense (Dense)                        │ (None, 128)                 │         100,480 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dense_1 (Dense)                      │ (None, 10)                  │           1,290 │
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
 Total params: 203,542 (795.09 KB)
 Trainable params: 101,770 (397.54 KB)
 Non-trainable params: 0 (0.00 B)
 Optimizer params: 101,772 (397.55 KB)

3.3 保存模型权重到文件

weights = model.get_weights()
print("weight ", weights)
model.save_weights('fashion.weights.h5')
model.load_weights('fashion.weights.h5')
print("weight from file", model.get_weights())
weight  [array([[-0.06041221, -0.03045469, -0.06056997, ...,  0.06603239,
        -0.06018624, -0.02584767],
       [-0.06430402, -0.07436118, -0.00909608, ..., -0.04476351,
        -0.01347907,  0.00300767],
       [ 0.07909157, -0.0689464 ,  0.07742291, ..., -0.00037885,
        -0.02884226,  0.05017615],
       ...,
       [-0.00013881,  0.0794938 ,  0.00120725, ..., -0.00251798,
        -0.06103022, -0.05509381],
       [ 0.04131137, -0.0285325 ,  0.06929631, ...,  0.07573903,
         0.02105945, -0.0524031 ],
       [ 0.07209501, -0.05137011, -0.07911879, ...,  0.02135488,
         0.0670035 ,  0.02766179]], dtype=float32), array([-0.00600429, -0.00547086, -0.00584014, -0.00600401, -0.00600361,
       -0.00565217, -0.00043141, -0.00599924, -0.00380762, -0.00364303,
       -0.00600468, -0.00330669, -0.00374643, -0.00600456, -0.0060048 ,
       -0.00600465, -0.0060041 , -0.00696887, -0.0011937 , -0.00599459,
       -0.00600372, -0.00600169, -0.00512277, -0.00579378, -0.00599535,
       -0.00598798, -0.00369858, -0.00600331, -0.00596425, -0.00598993,
       -0.00331114, -0.00600269, -0.00648344, -0.00598456, -0.00600508,
       -0.0050234 , -0.00600506, -0.00600394, -0.00370826, -0.00600255,
       -0.00318562, -0.0008926 , -0.00600376, -0.00600392, -0.00600293,
       -0.0010591 , -0.00526909, -0.0044194 , -0.0060979 , -0.00359087,
       -0.00599469, -0.00600368, -0.00600309, -0.00600125, -0.0060042 ,
       -0.0060032 , -0.00277885, -0.00599926, -0.00199332, -0.00494259,
       -0.00267067, -0.00600501, -0.0060036 , -0.00600471, -0.0060045 ,
       -0.00259782, -0.0027171 , -0.0060039 , -0.00141335, -0.00366305,
       -0.00254625, -0.00596222, -0.00328439, -0.00600358, -0.00597709,
       -0.00600401, -0.00600445, -0.00635821, -0.00166575, -0.00600483,
       -0.00459235, -0.00600466, -0.00637798, -0.00588632, -0.00599989,
       -0.0034114 , -0.00600291, -0.00600177, -0.00640314, -0.00600435,
       -0.00600042, -0.00600292, -0.00600482, -0.00600426, -0.00473085,
       -0.00157892, -0.00600219, -0.00364143, -0.00600267, -0.00600363,
       -0.00281488, -0.00600338, -0.00600482, -0.0025767 , -0.00744624,
       -0.00600235, -0.0060039 , -0.00600472, -0.00109048, -0.00483145,
       -0.00587764, -0.00600309, -0.00598578, -0.00599881, -0.00370371,
       -0.00600146, -0.00597422, -0.00600465, -0.00600461, -0.0060043 ,
       -0.00600423, -0.00243223, -0.00600425, -0.00600203, -0.0045927 ,
       -0.00371987, -0.00176624, -0.00600512], dtype=float32), array([[ 0.03556623,  0.1688491 , -0.10362723, ...,  0.13207223,
        -0.06696159, -0.15404737],
       [ 0.08589712,  0.0726881 , -0.03621184, ..., -0.13316402,
        -0.11030427, -0.07204279],
       [-0.02775251,  0.12212092,  0.12542443, ...,  0.05409406,
         0.07715587,  0.12737972],
       ...,
       [-0.12100082, -0.0844327 ,  0.03725254, ...,  0.04297927,
        -0.06126365, -0.04448495],
       [ 0.00898614,  0.11527378, -0.10356722, ..., -0.09458876,
        -0.02348839,  0.11287841],
       [-0.14625832, -0.17126669, -0.0226883 , ..., -0.1290805 ,
         0.1703024 ,  0.10214148]], dtype=float32), array([ 0.00133452, -0.03093298, -0.00157637,  0.0076253 , -0.00787955,
        0.0139694 ,  0.0038848 , -0.004496  , -0.00424037, -0.00311988],
      dtype=float32)]
weight from file [array([[-0.06041221, -0.03045469, -0.06056997, ...,  0.06603239,
        -0.06018624, -0.02584767],
       [-0.06430402, -0.07436118, -0.00909608, ..., -0.04476351,
        -0.01347907,  0.00300767],
       [ 0.07909157, -0.0689464 ,  0.07742291, ..., -0.00037885,
        -0.02884226,  0.05017615],
       ...,
       [-0.00013881,  0.0794938 ,  0.00120725, ..., -0.00251798,
        -0.06103022, -0.05509381],
       [ 0.04131137, -0.0285325 ,  0.06929631, ...,  0.07573903,
         0.02105945, -0.0524031 ],
       [ 0.07209501, -0.05137011, -0.07911879, ...,  0.02135488,
         0.0670035 ,  0.02766179]], dtype=float32), array([-0.00600429, -0.00547086, -0.00584014, -0.00600401, -0.00600361,
       -0.00565217, -0.00043141, -0.00599924, -0.00380762, -0.00364303,
       -0.00600468, -0.00330669, -0.00374643, -0.00600456, -0.0060048 ,
       -0.00600465, -0.0060041 , -0.00696887, -0.0011937 , -0.00599459,
       -0.00600372, -0.00600169, -0.00512277, -0.00579378, -0.00599535,
       -0.00598798, -0.00369858, -0.00600331, -0.00596425, -0.00598993,
       -0.00331114, -0.00600269, -0.00648344, -0.00598456, -0.00600508,
       -0.0050234 , -0.00600506, -0.00600394, -0.00370826, -0.00600255,
       -0.00318562, -0.0008926 , -0.00600376, -0.00600392, -0.00600293,
       -0.0010591 , -0.00526909, -0.0044194 , -0.0060979 , -0.00359087,
       -0.00599469, -0.00600368, -0.00600309, -0.00600125, -0.0060042 ,
       -0.0060032 , -0.00277885, -0.00599926, -0.00199332, -0.00494259,
       -0.00267067, -0.00600501, -0.0060036 , -0.00600471, -0.0060045 ,
       -0.00259782, -0.0027171 , -0.0060039 , -0.00141335, -0.00366305,
       -0.00254625, -0.00596222, -0.00328439, -0.00600358, -0.00597709,
       -0.00600401, -0.00600445, -0.00635821, -0.00166575, -0.00600483,
       -0.00459235, -0.00600466, -0.00637798, -0.00588632, -0.00599989,
       -0.0034114 , -0.00600291, -0.00600177, -0.00640314, -0.00600435,
       -0.00600042, -0.00600292, -0.00600482, -0.00600426, -0.00473085,
       -0.00157892, -0.00600219, -0.00364143, -0.00600267, -0.00600363,
       -0.00281488, -0.00600338, -0.00600482, -0.0025767 , -0.00744624,
       -0.00600235, -0.0060039 , -0.00600472, -0.00109048, -0.00483145,
       -0.00587764, -0.00600309, -0.00598578, -0.00599881, -0.00370371,
       -0.00600146, -0.00597422, -0.00600465, -0.00600461, -0.0060043 ,
       -0.00600423, -0.00243223, -0.00600425, -0.00600203, -0.0045927 ,
       -0.00371987, -0.00176624, -0.00600512], dtype=float32), array([[ 0.03556623,  0.1688491 , -0.10362723, ...,  0.13207223,
        -0.06696159, -0.15404737],
       [ 0.08589712,  0.0726881 , -0.03621184, ..., -0.13316402,
        -0.11030427, -0.07204279],
       [-0.02775251,  0.12212092,  0.12542443, ...,  0.05409406,
         0.07715587,  0.12737972],
       ...,
       [-0.12100082, -0.0844327 ,  0.03725254, ...,  0.04297927,
        -0.06126365, -0.04448495],
       [ 0.00898614,  0.11527378, -0.10356722, ..., -0.09458876,
        -0.02348839,  0.11287841],
       [-0.14625832, -0.17126669, -0.0226883 , ..., -0.1290805 ,
         0.1703024 ,  0.10214148]], dtype=float32), array([ 0.00133452, -0.03093298, -0.00157637,  0.0076253 , -0.00787955,
        0.0139694 ,  0.0038848 , -0.004496  , -0.00424037, -0.00311988],
      dtype=float32)]

http://www.kler.cn/a/307613.html

相关文章:

  • Python学习从0到1 day29 Python 高阶技巧 ⑦ 正则表达式
  • 《C语言程序设计现代方法》note-4 基本类型 强制类型转换 类型定义
  • 直接映射缓存配置
  • 深度学习的多主机多GPU协同训练
  • 《动手学深度学习》中d2l库的安装以及问题解决
  • JSON.stringify的应用说明
  • uniapp js修改数组某个下标以外的所有值
  • 2024.09.08 校招 实习 内推 面经
  • python Open3D 验证安装崩溃
  • 论文内容分类与检测系统源码分享
  • String 72变 ---------各种字符串处理方法
  • WSL挂载U盘或移动硬盘
  • 一起对话式学习-机器学习02——机器学习方法三要素
  • Apache-wed服务器环境的安装
  • 智能工厂的设计软件 单一面问题分析方法的“独角兽”程序
  • JVM面试真题总结(七)
  • 总结对象相关知识
  • Go语言并发编程之select语句详解
  • 【相机方案(2)】V4L2 支持相机图像直接进入GPU内存吗?DeepStream 确实可以将图像数据高效地放入GPU内存进行处理!
  • 后端开发刷题 | 打家劫舍
  • gin基本使用
  • 30款免费好用的工具,打工人必备!
  • 基于Keil软件实现实时时钟(江协科技HAL库)
  • Java-数据结构-二叉树-基础 (o゚▽゚)o
  • 代码随想录训练营Day3 | 链表理论基础 | 203.移除链表元素 | 707.设计链表 | 206.反转链表
  • Flink学习2