import matplotlib.pyplot as plt
import numpy as np
from tensorflow import keras
1. 数据预处理
1.1 下载数据集
# Load the Fashion-MNIST dataset through Keras. load_data() returns two
# (images, labels) pairs: the training split and the test split.
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Sanity-check the training split: array shapes and the first label.
print("train_images shape ", train_images.shape)
print("train_labels shape ", train_labels.shape)
print("train_labels[0] ", train_labels[0])
train_images shape (60000, 28, 28)
train_labels shape (60000,)
train_labels[0] 9
1.2 展示数据集的第一张图片
# Show the first training image as loaded (before the /255.0 normalization
# applied further down), with a colour bar so the pixel-value range is visible.
plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.grid(False)
# Bug fix: the original wrote `plt.show` without parentheses, which only
# evaluates the function object (the notebook echoed its repr) instead of
# rendering the figure.
plt.show()
<function matplotlib.pyplot.show(close=None, block=None)>
1.3 像素值归一化，并展示前25张图片及其类别名称
# Scale pixel intensities from the 0-255 range down to [0, 1].
# The guards make this cell idempotent. Re-running it in a notebook would
# otherwise divide by 255 again and push every input towards zero, which can
# leave the network stuck at chance level.
# NOTE(review): the training log below (accuracy ~0.10, loss ~2.3026 = ln 10)
# is consistent with such repeated scaling. Confirm after re-running.
if train_images.max() > 1.0:
    train_images = train_images / 255.0
if test_images.max() > 1.0:
    test_images = test_images / 255.0

plt.figure(figsize=(10, 10))
# Human-readable class names, indexed by the integer label.
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
print("train_labels ", train_labels[:25])
# Draw the first 25 training images in a 5x5 grid, each captioned with its
# class name.
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])
plt.show()
train_labels [9 0 0 3 0 2 7 2 5 5 0 9 5 5 7 9 1 0 6 4 3 1 4 8 4]
2. 模型实现
2.1 模型定义、编译与训练
# A small fully-connected classifier:
# 28x28 image -> flattened to 784 values -> 128 ReLU units -> 10-way softmax.
model = keras.Sequential([
    # Declare the input shape with an explicit Input object. Passing
    # `input_shape=` to Flatten makes Keras emit a UserWarning recommending
    # exactly this form.
    keras.Input(shape=(28, 28)),
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation="relu"),
    keras.layers.Dense(10, activation="softmax"),
])
# Labels are plain integers (not one-hot), hence the *sparse* categorical
# cross-entropy loss.
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=10)
D:\python\Lib\site-packages\keras\src\layers\reshaping\flatten.py:37: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.
super().__init__(**kwargs)
Epoch 1/10
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 2s 977us/step - accuracy: 0.0967 - loss: 2.3028
Epoch 2/10
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - accuracy: 0.0991 - loss: 2.3027
Epoch 3/10
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - accuracy: 0.0956 - loss: 2.3028
Epoch 4/10
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - accuracy: 0.0987 - loss: 2.3027
Epoch 5/10
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 2s 968us/step - accuracy: 0.0988 - loss: 2.3028
Epoch 6/10
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - accuracy: 0.1009 - loss: 2.3027
Epoch 7/10
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - accuracy: 0.0998 - loss: 2.3027
Epoch 8/10
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - accuracy: 0.0968 - loss: 2.3028
Epoch 9/10
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - accuracy: 0.1036 - loss: 2.3027
Epoch 10/10
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 2s 987us/step - accuracy: 0.0973 - loss: 2.3028
<keras.src.callbacks.history.History at 0x20049c207d0>
2.2 在测试集上评估模型
# Score the trained model on the held-out test split; verbose=2 makes Keras
# print a single summary line rather than a progress bar.
evaluation = model.evaluate(test_images, test_labels, verbose=2)
test_loss, test_accuracy = evaluation
print("test_loss ", test_loss)
print("test_accuracy", test_accuracy)
313/313 - 0s - 892us/step - accuracy: 0.1000 - loss: 2.3026
test_loss 2.3026490211486816
test_accuracy 0.10000000149011612
2.3 模型预测
# One row per test image, one softmax score per class (shape printed below).
predict_result = model.predict(test_images)
print("predict_result shape, 样本数,每个样本对每个分类的得分 ", predict_result.shape)

first_scores = predict_result[0]
print("样本1的每个分类得分, ", first_scores)

# The predicted class is the index of the highest score; compare it with
# the ground-truth label of the same test image.
sample_one_result = first_scores.argmax()
true_class = test_labels[0]
print("样本1的分类结果%d %s" % (sample_one_result, class_names[sample_one_result]))
print("样本1的真实分类结果%d %s" % (true_class, class_names[true_class]))
313/313 ━━━━━━━━━━━━━━━━━━━━ 0s 671us/step
predict_result shape, 样本数,每个样本对每个分类的得分 (10000, 10)
样本1的每个分类得分, [0.10038214 0.09719477 0.10009037 0.10101561 0.09946147 0.10165851
0.10063848 0.09979857 0.09982409 0.09993599]
样本1的分类结果5 Sandal
样本1的真实分类结果9 Ankle boot
2.4 查看指定测试图片的预测结果
def plot_image(index, predict_classes, true_labels, images):
    """Draw ``images[index]`` captioned with the model's prediction for it.

    Args:
        index: position of the sample in ``images`` and ``true_labels``.
        predict_classes: per-class scores for this one sample (for example one
            row of ``model.predict`` output).
        true_labels: integer class labels, aligned with ``images``.
        images: the image arrays to display.

    The caption reads "<predicted name> <top score as %>(<true name>)". It is
    blue when the prediction matches the label and red otherwise. Class
    names come from the module-level ``class_names`` list.
    """
    expected = true_labels[index]
    guessed = np.argmax(predict_classes)

    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(images[index], cmap=plt.cm.binary)

    caption_color = 'blue' if guessed == expected else 'red'
    caption = "{} {:2.0f}%({})".format(
        class_names[guessed],
        100 * np.max(predict_classes),
        class_names[expected],
    )
    plt.xlabel(caption, color=caption_color)
def plot_predict_classes(i, predict_classes, true_labels):
    """Draw a bar chart of one sample's per-class scores.

    Args:
        i: index of the sample in ``true_labels``.
        predict_classes: the per-class scores for that sample (one bar each).
        true_labels: integer class labels of the dataset being plotted.

    The predicted class's bar is coloured red and the true class's bar blue.
    Because blue is applied last, a correct prediction shows a single blue bar.
    """
    # Bug fix: the original read the global ``train_labels`` here and ignored
    # the ``true_labels`` argument, so plots of test samples highlighted the
    # label of the *training* sample with the same index.
    true_label = true_labels[i]
    # Derive the class count from the scores instead of hard-coding 10.
    n_classes = len(predict_classes)
    plt.grid(False)
    plt.xticks(range(n_classes))
    plt.yticks([])
    bars = plt.bar(range(n_classes), predict_classes, color="#777777")
    # Scores are expected to be probabilities (the model ends in softmax),
    # so a fixed [0, 1] y-range keeps charts comparable.
    plt.ylim([0, 1])
    predict_label = np.argmax(predict_classes)
    bars[predict_label].set_color("red")
    bars[true_label].set_color('blue')
# Visualize test sample 0: the image with its prediction caption on the left,
# the per-class score bars on the right.
i = 0
sample_scores = predict_result[i]
plt.figure(figsize=(6, 3))

plt.subplot(1, 2, 1)  # left panel: image + caption
plot_image(i, sample_scores, test_labels, test_images)

plt.subplot(1, 2, 2)  # right panel: score bar chart
plot_predict_classes(i, sample_scores, test_labels)

plt.show()
3. 保存训练好的模型
3.1 保存模型
# Save the trained model to a single HDF5 (.h5) file.
# NOTE: Keras warns that HDF5 is a legacy format and recommends the native
# ".keras" format (e.g. model.save('fashion_model.keras')) for new code.
model.save(filepath='fashion_model.h5')
WARNING:absl:You are saving your model as an HDF5 file via `model.save()` or `keras.saving.save_model(model)`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')` or `keras.saving.save_model(model, 'my_model.keras')`.
3.2 保存模型结构到 JSON 文件
# Serialize the model *architecture* to JSON. The printed JSON holds layer
# and initializer configs, not trained weight values.
model_json = model.to_json()
print("model json: ", model_json)
# The handle is named `json_file`, not `json`, so it does not shadow the
# standard-library module name. The encoding is explicit instead of
# platform-dependent.
with open('fashion_model_config.json', 'w', encoding='utf-8') as json_file:
    json_file.write(model_json)

# Rebuild a model from the architecture alone. It has the same layers but
# does not carry the trained weights.
print("json from model")
json_model = keras.models.model_from_json(model_json)
json_model.summary()
model json: {"module": "keras", "class_name": "Sequential", "config": {"name": "sequential", "trainable": true, "dtype": {"module": "keras", "class_name": "DTypePolicy", "config": {"name": "float32"}, "registered_name": null}, "layers": [{"module": "keras.layers", "class_name": "InputLayer", "config": {"batch_shape": [null, 28, 28], "dtype": "float32", "sparse": false, "name": "input_layer"}, "registered_name": null}, {"module": "keras.layers", "class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": {"module": "keras", "class_name": "DTypePolicy", "config": {"name": "float32"}, "registered_name": null}, "data_format": "channels_last"}, "registered_name": null, "build_config": {"input_shape": [null, 28, 28]}}, {"module": "keras.layers", "class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": {"module": "keras", "class_name": "DTypePolicy", "config": {"name": "float32"}, "registered_name": null}, "units": 128, "activation": "relu", "use_bias": true, "kernel_initializer": {"module": "keras.initializers", "class_name": "GlorotUniform", "config": {"seed": null}, "registered_name": null}, "bias_initializer": {"module": "keras.initializers", "class_name": "Zeros", "config": {}, "registered_name": null}, "kernel_regularizer": null, "bias_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "registered_name": null, "build_config": {"input_shape": [null, 784]}}, {"module": "keras.layers", "class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": {"module": "keras", "class_name": "DTypePolicy", "config": {"name": "float32"}, "registered_name": null}, "units": 10, "activation": "softmax", "use_bias": true, "kernel_initializer": {"module": "keras.initializers", "class_name": "GlorotUniform", "config": {"seed": null}, "registered_name": null}, "bias_initializer": {"module": "keras.initializers", "class_name": "Zeros", "config": {}, "registered_name": null}, "kernel_regularizer": null, 
"bias_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "registered_name": null, "build_config": {"input_shape": [null, 128]}}], "build_input_shape": [null, 28, 28]}, "registered_name": null, "build_config": {"input_shape": [null, 28, 28]}, "compile_config": {"loss": "sparse_categorical_crossentropy", "loss_weights": null, "metrics": ["accuracy"], "weighted_metrics": null, "run_eagerly": false, "steps_per_execution": 1, "jit_compile": false}}
json from model
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Layer (type) ┃ Output Shape ┃ Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ flatten (Flatten) │ (None, 784) │ 0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dense (Dense) │ (None, 128) │ 100,480 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dense_1 (Dense) │ (None, 10) │ 1,290 │
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
Total params: 203,542 (795.09 KB)
Trainable params: 101,770 (397.54 KB)
Non-trainable params: 0 (0.00 B)
Optimizer params: 101,772 (397.55 KB)
3.3 保存模型权重到文件并重新加载
# Round-trip the trained weights through a file and print them before and
# after, so the two printouts can be compared for equality.
weights_path = 'fashion.weights.h5'
weights = model.get_weights()
print("weight ", weights)
model.save_weights(weights_path)
model.load_weights(weights_path)
reloaded = model.get_weights()
print("weight from file", reloaded)
weight [array([[-0.06041221, -0.03045469, -0.06056997, ..., 0.06603239,
-0.06018624, -0.02584767],
[-0.06430402, -0.07436118, -0.00909608, ..., -0.04476351,
-0.01347907, 0.00300767],
[ 0.07909157, -0.0689464 , 0.07742291, ..., -0.00037885,
-0.02884226, 0.05017615],
...,
[-0.00013881, 0.0794938 , 0.00120725, ..., -0.00251798,
-0.06103022, -0.05509381],
[ 0.04131137, -0.0285325 , 0.06929631, ..., 0.07573903,
0.02105945, -0.0524031 ],
[ 0.07209501, -0.05137011, -0.07911879, ..., 0.02135488,
0.0670035 , 0.02766179]], dtype=float32), array([-0.00600429, -0.00547086, -0.00584014, -0.00600401, -0.00600361,
-0.00565217, -0.00043141, -0.00599924, -0.00380762, -0.00364303,
-0.00600468, -0.00330669, -0.00374643, -0.00600456, -0.0060048 ,
-0.00600465, -0.0060041 , -0.00696887, -0.0011937 , -0.00599459,
-0.00600372, -0.00600169, -0.00512277, -0.00579378, -0.00599535,
-0.00598798, -0.00369858, -0.00600331, -0.00596425, -0.00598993,
-0.00331114, -0.00600269, -0.00648344, -0.00598456, -0.00600508,
-0.0050234 , -0.00600506, -0.00600394, -0.00370826, -0.00600255,
-0.00318562, -0.0008926 , -0.00600376, -0.00600392, -0.00600293,
-0.0010591 , -0.00526909, -0.0044194 , -0.0060979 , -0.00359087,
-0.00599469, -0.00600368, -0.00600309, -0.00600125, -0.0060042 ,
-0.0060032 , -0.00277885, -0.00599926, -0.00199332, -0.00494259,
-0.00267067, -0.00600501, -0.0060036 , -0.00600471, -0.0060045 ,
-0.00259782, -0.0027171 , -0.0060039 , -0.00141335, -0.00366305,
-0.00254625, -0.00596222, -0.00328439, -0.00600358, -0.00597709,
-0.00600401, -0.00600445, -0.00635821, -0.00166575, -0.00600483,
-0.00459235, -0.00600466, -0.00637798, -0.00588632, -0.00599989,
-0.0034114 , -0.00600291, -0.00600177, -0.00640314, -0.00600435,
-0.00600042, -0.00600292, -0.00600482, -0.00600426, -0.00473085,
-0.00157892, -0.00600219, -0.00364143, -0.00600267, -0.00600363,
-0.00281488, -0.00600338, -0.00600482, -0.0025767 , -0.00744624,
-0.00600235, -0.0060039 , -0.00600472, -0.00109048, -0.00483145,
-0.00587764, -0.00600309, -0.00598578, -0.00599881, -0.00370371,
-0.00600146, -0.00597422, -0.00600465, -0.00600461, -0.0060043 ,
-0.00600423, -0.00243223, -0.00600425, -0.00600203, -0.0045927 ,
-0.00371987, -0.00176624, -0.00600512], dtype=float32), array([[ 0.03556623, 0.1688491 , -0.10362723, ..., 0.13207223,
-0.06696159, -0.15404737],
[ 0.08589712, 0.0726881 , -0.03621184, ..., -0.13316402,
-0.11030427, -0.07204279],
[-0.02775251, 0.12212092, 0.12542443, ..., 0.05409406,
0.07715587, 0.12737972],
...,
[-0.12100082, -0.0844327 , 0.03725254, ..., 0.04297927,
-0.06126365, -0.04448495],
[ 0.00898614, 0.11527378, -0.10356722, ..., -0.09458876,
-0.02348839, 0.11287841],
[-0.14625832, -0.17126669, -0.0226883 , ..., -0.1290805 ,
0.1703024 , 0.10214148]], dtype=float32), array([ 0.00133452, -0.03093298, -0.00157637, 0.0076253 , -0.00787955,
0.0139694 , 0.0038848 , -0.004496 , -0.00424037, -0.00311988],
dtype=float32)]
weight from file [array([[-0.06041221, -0.03045469, -0.06056997, ..., 0.06603239,
-0.06018624, -0.02584767],
[-0.06430402, -0.07436118, -0.00909608, ..., -0.04476351,
-0.01347907, 0.00300767],
[ 0.07909157, -0.0689464 , 0.07742291, ..., -0.00037885,
-0.02884226, 0.05017615],
...,
[-0.00013881, 0.0794938 , 0.00120725, ..., -0.00251798,
-0.06103022, -0.05509381],
[ 0.04131137, -0.0285325 , 0.06929631, ..., 0.07573903,
0.02105945, -0.0524031 ],
[ 0.07209501, -0.05137011, -0.07911879, ..., 0.02135488,
0.0670035 , 0.02766179]], dtype=float32), array([-0.00600429, -0.00547086, -0.00584014, -0.00600401, -0.00600361,
-0.00565217, -0.00043141, -0.00599924, -0.00380762, -0.00364303,
-0.00600468, -0.00330669, -0.00374643, -0.00600456, -0.0060048 ,
-0.00600465, -0.0060041 , -0.00696887, -0.0011937 , -0.00599459,
-0.00600372, -0.00600169, -0.00512277, -0.00579378, -0.00599535,
-0.00598798, -0.00369858, -0.00600331, -0.00596425, -0.00598993,
-0.00331114, -0.00600269, -0.00648344, -0.00598456, -0.00600508,
-0.0050234 , -0.00600506, -0.00600394, -0.00370826, -0.00600255,
-0.00318562, -0.0008926 , -0.00600376, -0.00600392, -0.00600293,
-0.0010591 , -0.00526909, -0.0044194 , -0.0060979 , -0.00359087,
-0.00599469, -0.00600368, -0.00600309, -0.00600125, -0.0060042 ,
-0.0060032 , -0.00277885, -0.00599926, -0.00199332, -0.00494259,
-0.00267067, -0.00600501, -0.0060036 , -0.00600471, -0.0060045 ,
-0.00259782, -0.0027171 , -0.0060039 , -0.00141335, -0.00366305,
-0.00254625, -0.00596222, -0.00328439, -0.00600358, -0.00597709,
-0.00600401, -0.00600445, -0.00635821, -0.00166575, -0.00600483,
-0.00459235, -0.00600466, -0.00637798, -0.00588632, -0.00599989,
-0.0034114 , -0.00600291, -0.00600177, -0.00640314, -0.00600435,
-0.00600042, -0.00600292, -0.00600482, -0.00600426, -0.00473085,
-0.00157892, -0.00600219, -0.00364143, -0.00600267, -0.00600363,
-0.00281488, -0.00600338, -0.00600482, -0.0025767 , -0.00744624,
-0.00600235, -0.0060039 , -0.00600472, -0.00109048, -0.00483145,
-0.00587764, -0.00600309, -0.00598578, -0.00599881, -0.00370371,
-0.00600146, -0.00597422, -0.00600465, -0.00600461, -0.0060043 ,
-0.00600423, -0.00243223, -0.00600425, -0.00600203, -0.0045927 ,
-0.00371987, -0.00176624, -0.00600512], dtype=float32), array([[ 0.03556623, 0.1688491 , -0.10362723, ..., 0.13207223,
-0.06696159, -0.15404737],
[ 0.08589712, 0.0726881 , -0.03621184, ..., -0.13316402,
-0.11030427, -0.07204279],
[-0.02775251, 0.12212092, 0.12542443, ..., 0.05409406,
0.07715587, 0.12737972],
...,
[-0.12100082, -0.0844327 , 0.03725254, ..., 0.04297927,
-0.06126365, -0.04448495],
[ 0.00898614, 0.11527378, -0.10356722, ..., -0.09458876,
-0.02348839, 0.11287841],
[-0.14625832, -0.17126669, -0.0226883 , ..., -0.1290805 ,
0.1703024 , 0.10214148]], dtype=float32), array([ 0.00133452, -0.03093298, -0.00157637, 0.0076253 , -0.00787955,
0.0139694 , 0.0038848 , -0.004496 , -0.00424037, -0.00311988],
dtype=float32)]