第R3周:RNN-心脏病预测(TensorFlow版)
>- **🍨 本文为[🔗365天深度学习训练营]中的学习记录博客**
>- **🍖 原作者:[K同学啊]**
🍺 要求:
- 找到并处理第8周的程序问题(本文给出了答案)
- 了解循环神经网络(RNN)的构建过程。
- 测试集accuracy到达87%。
🍻 拔高(可选):
- 测试集accuracy到达89%
往期文章可查阅: 深度学习总结
🚀我的环境:
- 语言环境:Python3.11.7
- 编译器:jupyter notebook
- 深度学习框架:TensorFlow2.13.0
代码流程图如下所示:
一、RNN是什么
传统神经网络的结构比较简单:输入层 – 隐藏层 – 输出层
RNN 跟传统神经网络最大的区别在于每次都会将前一次的输出结果,带到下一次的隐藏层中,一起训练。如下图所示:
这里用一个具体的案例来看看 RNN 是如何工作的:用户说了一句“what time is it?”,我们的神经网络会先将这句话分为五个基本单元(四个单词+一个问号)
然后,按照顺序将五个基本单元输入RNN网络,先将 “what”作为RNN的输入,得到输出 01
随后,按照顺序将“time”输入到RNN网络,得到输出02。
这个过程我们可以看到,输入 “time” 的时候,前面“what” 的输出也会对02的输出产生了影响(隐藏层中有一半是黑色的)。
以此类推,我们可以看到,前面所有的输入产生的结果都对后续的输出产生了影响(可以看到圆形中包含了前面所有的颜色)。
当神经网络判断意图的时候,只需要最后一层的输出 05,如下图所示:
二、前期准备
1. 设置GPU
import tensorflow as tf
gpus=tf.config.list_physical_devices("GPU")
if gpus:
gpu0=gpus[0] #如果有多个GPU,仅使用第0个GPU
tf.config.experimental.set_memory_growth(gpu0,True) #设置GPU显存用量按需使用
tf.config.set_visible_devices([gpu0],"GPU")
gpus
运行结果:
[]
因为我的TensorFlow没有安装GPU版本,故显示此。
2. 导入数据
数据介绍:
●age:1) 年龄
●sex:2) 性别
●cp:3) 胸痛类型 (4 values)
●trestbps:4) 静息血压
●chol:5) 血清胆甾醇 (mg/dl
●fbs:6) 空腹血糖 > 120 mg/dl
●restecg:7) 静息心电图结果 (值 0,1 ,2)
●thalach:8) 达到的最大心率
●exang:9) 运动诱发的心绞痛
●oldpeak:10) 相对于静止状态,运动引起的ST段压低
●slope:11) 运动峰值 ST 段的斜率
●ca:12) 荧光透视着色的主要血管数量 (0-3)
●thal:13) 0 = 正常;1 = 固定缺陷;2 = 可逆转的缺陷
●target:14) 0 = 心脏病发作的几率较小 1 = 心脏病发作的几率更大
import pandas as pd
import numpy as np
df=pd.read_csv("D:\THE MNIST DATABASE\RNN\R3\heart.csv")
df
运行结果:
3. 检查数据
# 检查是否有空值
df.isnull().sum()
运行结果:
age 0
sex 0
cp 0
trestbps 0
chol 0
fbs 0
restecg 0
thalach 0
exang 0
oldpeak 0
slope 0
ca 0
thal 0
target 0
dtype: int64
三、数据预处理
1. 划分训练集与测试集
测试集与验证集的关系:
(1)验证集并没有参与训练过程梯度下降过程的,狭义上来讲是没有参与模型的参数训练更新的。
(2)但是广义上来讲,验证集存在的意义确实参与了一个“人工调参”的过程,我们根据每一个epoch训练之后模型在valid data上的表现来决定是否需要训练进行early stop,或者根据这个过程模型的性能变化来调整模型的超参数,如学习率,batch_size等等。
(3)我们也可以认为,验证集也参与了训练,但是并没有使得模型去overfit验证集。
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
x=df.iloc[:,:-1]
y=df.iloc[:,-1]
x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.1,random_state=1)
x_train.shape,y_train.shape
运行结果:
((272, 13), (272,))
2. 标准化
# 将每一列特征标准化为标准正态分布,注意,标准化是针对每一列而言的
sc=StandardScaler()
x_train=sc.fit_transform(x_train)
x_test=sc.transform(x_test)
x_train=x_train.reshape(x_train.shape[0],x_train.shape[1],1)
x_test=x_test.reshape(x_test.shape[0],x_test.shape[1],1)
四、构建RNN模型
函数原型:
tf.keras.layers.SimpleRNN(units,activation=‘tanh’,use_bias=True,kernel_initializer=‘glorot_uniform’,
recurrent_initializer=‘orthogonal’,bias_initializer=‘zeros’,kernel_regularizer=None,recurrent_regularizer=None,bias_regularizer=None,activity_regularizer=None,kernel_constraint=None,recurrent_constraint=None,
bias_constraint=None,dropout=0.0,recurrent_dropout=0.0,return_sequences=False,return_state=False,
go_backwards=False,stateful=False,unroll=False,**kwargs)
关键参数说明:
●units: 正整数,输出空间的维度。
●activation: 要使用的激活函数。 默认:双曲正切(tanh)。 如果传入 None,则不使用激活函数 (即 线性激活:a(x) = x)。
●use_bias: 布尔值,该层是否使用偏置向量。
●kernel_initializer: kernel 权值矩阵的初始化器, 用于输入的线性转换 (详见 initializers)。
●recurrent_initializer: recurrent_kernel 权值矩阵 的初始化器,用于循环层状态的线性转换 (详见 initializers)。
●bias_initializer:偏置向量的初始化器 (详见initializers)。
●dropout: 在 0 和 1 之间的浮点数。 单元的丢弃比例,用于输入的线性转换。
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,LSTM,SimpleRNN
model=Sequential()
model.add(SimpleRNN(200,input_shape=(13,1),activation='relu'))
model.add(Dense(100,activation='relu'))
model.add(Dense(1,activation='sigmoid'))
model.summary()
运行结果:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
simple_rnn (SimpleRNN) (None, 200) 40400
dense (Dense) (None, 100) 20100
dense_1 (Dense) (None, 1) 101
=================================================================
Total params: 60601 (236.72 KB)
Trainable params: 60601 (236.72 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
五、编译模型
opt=tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(loss='binary_crossentropy',
optimizer=opt,
metrics="accuracy")
六、训练模型
epochs=100
history=model.fit(x_train,y_train,
epochs=epochs,
batch_size=128,
validation_data=(x_test,y_test),
verbose=1)
运行结果:
Epoch 1/100
3/3 [==============================] - 1s 140ms/step - loss: 0.6778 - accuracy: 0.7206 - val_loss: 0.6456 - val_accuracy: 0.8710
Epoch 2/100
3/3 [==============================] - 0s 16ms/step - loss: 0.6669 - accuracy: 0.7721 - val_loss: 0.6315 - val_accuracy: 0.8710
Epoch 3/100
3/3 [==============================] - 0s 16ms/step - loss: 0.6576 - accuracy: 0.7794 - val_loss: 0.6184 - val_accuracy: 0.8387
Epoch 4/100
3/3 [==============================] - 0s 16ms/step - loss: 0.6489 - accuracy: 0.7794 - val_loss: 0.6051 - val_accuracy: 0.8710
Epoch 5/100
3/3 [==============================] - 0s 16ms/step - loss: 0.6399 - accuracy: 0.7721 - val_loss: 0.5921 - val_accuracy: 0.8710
Epoch 6/100
3/3 [==============================] - 0s 16ms/step - loss: 0.6309 - accuracy: 0.7757 - val_loss: 0.5785 - val_accuracy: 0.8710
Epoch 7/100
3/3 [==============================] - 0s 16ms/step - loss: 0.6221 - accuracy: 0.7757 - val_loss: 0.5644 - val_accuracy: 0.8710
Epoch 8/100
3/3 [==============================] - 0s 16ms/step - loss: 0.6133 - accuracy: 0.7757 - val_loss: 0.5502 - val_accuracy: 0.8710
Epoch 9/100
3/3 [==============================] - 0s 15ms/step - loss: 0.6037 - accuracy: 0.7757 - val_loss: 0.5355 - val_accuracy: 0.8710
Epoch 10/100
3/3 [==============================] - 0s 17ms/step - loss: 0.5932 - accuracy: 0.7794 - val_loss: 0.5197 - val_accuracy: 0.8387
Epoch 11/100
3/3 [==============================] - 0s 16ms/step - loss: 0.5832 - accuracy: 0.7831 - val_loss: 0.5031 - val_accuracy: 0.8387
Epoch 12/100
3/3 [==============================] - 0s 17ms/step - loss: 0.5717 - accuracy: 0.7941 - val_loss: 0.4862 - val_accuracy: 0.8387
Epoch 13/100
3/3 [==============================] - 0s 16ms/step - loss: 0.5595 - accuracy: 0.7941 - val_loss: 0.4670 - val_accuracy: 0.8387
Epoch 14/100
3/3 [==============================] - 0s 14ms/step - loss: 0.5462 - accuracy: 0.8015 - val_loss: 0.4456 - val_accuracy: 0.8710
Epoch 15/100
3/3 [==============================] - 0s 16ms/step - loss: 0.5318 - accuracy: 0.8088 - val_loss: 0.4226 - val_accuracy: 0.8710
Epoch 16/100
3/3 [==============================] - 0s 16ms/step - loss: 0.5166 - accuracy: 0.8015 - val_loss: 0.3984 - val_accuracy: 0.8387
Epoch 17/100
3/3 [==============================] - 0s 16ms/step - loss: 0.5000 - accuracy: 0.8015 - val_loss: 0.3757 - val_accuracy: 0.8387
Epoch 18/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4845 - accuracy: 0.8088 - val_loss: 0.3550 - val_accuracy: 0.8387
Epoch 19/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4694 - accuracy: 0.8088 - val_loss: 0.3355 - val_accuracy: 0.8710
Epoch 20/100
3/3 [==============================] - 0s 17ms/step - loss: 0.4545 - accuracy: 0.8015 - val_loss: 0.3177 - val_accuracy: 0.8710
Epoch 21/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4425 - accuracy: 0.8015 - val_loss: 0.3035 - val_accuracy: 0.8710
Epoch 22/100
3/3 [==============================] - 0s 15ms/step - loss: 0.4350 - accuracy: 0.8015 - val_loss: 0.2928 - val_accuracy: 0.8710
Epoch 23/100
3/3 [==============================] - 0s 15ms/step - loss: 0.4264 - accuracy: 0.8015 - val_loss: 0.2856 - val_accuracy: 0.8710
Epoch 24/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4199 - accuracy: 0.7978 - val_loss: 0.2840 - val_accuracy: 0.9032
Epoch 25/100
3/3 [==============================] - 0s 17ms/step - loss: 0.4175 - accuracy: 0.8088 - val_loss: 0.2795 - val_accuracy: 0.9032
Epoch 26/100
3/3 [==============================] - 0s 15ms/step - loss: 0.4127 - accuracy: 0.8051 - val_loss: 0.2726 - val_accuracy: 0.9032
Epoch 27/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4080 - accuracy: 0.8088 - val_loss: 0.2675 - val_accuracy: 0.8710
Epoch 28/100
3/3 [==============================] - 0s 17ms/step - loss: 0.4088 - accuracy: 0.8125 - val_loss: 0.2663 - val_accuracy: 0.9032
Epoch 29/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4026 - accuracy: 0.8235 - val_loss: 0.2671 - val_accuracy: 0.9032
Epoch 30/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3955 - accuracy: 0.8125 - val_loss: 0.2701 - val_accuracy: 0.9032
Epoch 31/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3929 - accuracy: 0.8162 - val_loss: 0.2703 - val_accuracy: 0.9032
Epoch 32/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3905 - accuracy: 0.8199 - val_loss: 0.2686 - val_accuracy: 0.9032
Epoch 33/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3877 - accuracy: 0.8199 - val_loss: 0.2631 - val_accuracy: 0.9032
Epoch 34/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3832 - accuracy: 0.8235 - val_loss: 0.2568 - val_accuracy: 0.9032
Epoch 35/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3843 - accuracy: 0.8162 - val_loss: 0.2560 - val_accuracy: 0.9032
Epoch 36/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3809 - accuracy: 0.8199 - val_loss: 0.2577 - val_accuracy: 0.9032
Epoch 37/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3752 - accuracy: 0.8199 - val_loss: 0.2602 - val_accuracy: 0.9032
Epoch 38/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3718 - accuracy: 0.8309 - val_loss: 0.2629 - val_accuracy: 0.9032
Epoch 39/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3694 - accuracy: 0.8235 - val_loss: 0.2622 - val_accuracy: 0.9032
Epoch 40/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3666 - accuracy: 0.8272 - val_loss: 0.2601 - val_accuracy: 0.9032
Epoch 41/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3655 - accuracy: 0.8309 - val_loss: 0.2594 - val_accuracy: 0.9032
Epoch 42/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3643 - accuracy: 0.8346 - val_loss: 0.2587 - val_accuracy: 0.9032
Epoch 43/100
3/3 [==============================] - 0s 19ms/step - loss: 0.3600 - accuracy: 0.8382 - val_loss: 0.2610 - val_accuracy: 0.9032
Epoch 44/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3568 - accuracy: 0.8382 - val_loss: 0.2637 - val_accuracy: 0.9032
Epoch 45/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3560 - accuracy: 0.8346 - val_loss: 0.2608 - val_accuracy: 0.9032
Epoch 46/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3527 - accuracy: 0.8382 - val_loss: 0.2563 - val_accuracy: 0.9032
Epoch 47/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3506 - accuracy: 0.8382 - val_loss: 0.2541 - val_accuracy: 0.9032
Epoch 48/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3482 - accuracy: 0.8419 - val_loss: 0.2542 - val_accuracy: 0.9032
Epoch 49/100
3/3 [==============================] - 0s 14ms/step - loss: 0.3457 - accuracy: 0.8419 - val_loss: 0.2560 - val_accuracy: 0.9032
Epoch 50/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3418 - accuracy: 0.8456 - val_loss: 0.2558 - val_accuracy: 0.9032
Epoch 51/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3401 - accuracy: 0.8529 - val_loss: 0.2554 - val_accuracy: 0.9032
Epoch 52/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3381 - accuracy: 0.8529 - val_loss: 0.2577 - val_accuracy: 0.9032
Epoch 53/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3354 - accuracy: 0.8529 - val_loss: 0.2608 - val_accuracy: 0.9032
Epoch 54/100
3/3 [==============================] - 0s 13ms/step - loss: 0.3337 - accuracy: 0.8603 - val_loss: 0.2611 - val_accuracy: 0.9032
Epoch 55/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3318 - accuracy: 0.8603 - val_loss: 0.2628 - val_accuracy: 0.9032
Epoch 56/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3302 - accuracy: 0.8640 - val_loss: 0.2666 - val_accuracy: 0.9032
Epoch 57/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3292 - accuracy: 0.8603 - val_loss: 0.2669 - val_accuracy: 0.9032
Epoch 58/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3261 - accuracy: 0.8640 - val_loss: 0.2655 - val_accuracy: 0.9032
Epoch 59/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3231 - accuracy: 0.8640 - val_loss: 0.2669 - val_accuracy: 0.9032
Epoch 60/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3219 - accuracy: 0.8640 - val_loss: 0.2701 - val_accuracy: 0.9032
Epoch 61/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3207 - accuracy: 0.8676 - val_loss: 0.2714 - val_accuracy: 0.9032
Epoch 62/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3168 - accuracy: 0.8640 - val_loss: 0.2727 - val_accuracy: 0.9032
Epoch 63/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3150 - accuracy: 0.8640 - val_loss: 0.2709 - val_accuracy: 0.9032
Epoch 64/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3139 - accuracy: 0.8860 - val_loss: 0.2688 - val_accuracy: 0.8710
Epoch 65/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3130 - accuracy: 0.8934 - val_loss: 0.2700 - val_accuracy: 0.8710
Epoch 66/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3118 - accuracy: 0.8824 - val_loss: 0.2725 - val_accuracy: 0.8710
Epoch 67/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3077 - accuracy: 0.8897 - val_loss: 0.2765 - val_accuracy: 0.8710
Epoch 68/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3050 - accuracy: 0.8934 - val_loss: 0.2801 - val_accuracy: 0.9032
Epoch 69/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3022 - accuracy: 0.8897 - val_loss: 0.2821 - val_accuracy: 0.9032
Epoch 70/100
3/3 [==============================] - 0s 13ms/step - loss: 0.3002 - accuracy: 0.8897 - val_loss: 0.2837 - val_accuracy: 0.9032
Epoch 71/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2981 - accuracy: 0.8934 - val_loss: 0.2852 - val_accuracy: 0.8710
Epoch 72/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2973 - accuracy: 0.8860 - val_loss: 0.2870 - val_accuracy: 0.8710
Epoch 73/100
3/3 [==============================] - 0s 12ms/step - loss: 0.2962 - accuracy: 0.8860 - val_loss: 0.2873 - val_accuracy: 0.8710
Epoch 74/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2954 - accuracy: 0.8897 - val_loss: 0.2849 - val_accuracy: 0.8710
Epoch 75/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2909 - accuracy: 0.8860 - val_loss: 0.2865 - val_accuracy: 0.9032
Epoch 76/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2867 - accuracy: 0.8897 - val_loss: 0.2942 - val_accuracy: 0.9032
Epoch 77/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2888 - accuracy: 0.8824 - val_loss: 0.3043 - val_accuracy: 0.9032
Epoch 78/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2932 - accuracy: 0.8713 - val_loss: 0.3046 - val_accuracy: 0.9032
Epoch 79/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2871 - accuracy: 0.8824 - val_loss: 0.2997 - val_accuracy: 0.8710
Epoch 80/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2799 - accuracy: 0.8787 - val_loss: 0.2997 - val_accuracy: 0.8387
Epoch 81/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2790 - accuracy: 0.8860 - val_loss: 0.2980 - val_accuracy: 0.8387
Epoch 82/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2782 - accuracy: 0.8934 - val_loss: 0.2978 - val_accuracy: 0.8387
Epoch 83/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2758 - accuracy: 0.9007 - val_loss: 0.3000 - val_accuracy: 0.8710
Epoch 84/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2752 - accuracy: 0.8897 - val_loss: 0.3061 - val_accuracy: 0.9032
Epoch 85/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2794 - accuracy: 0.8787 - val_loss: 0.3143 - val_accuracy: 0.9032
Epoch 86/100
3/3 [==============================] - 0s 14ms/step - loss: 0.2809 - accuracy: 0.8750 - val_loss: 0.3126 - val_accuracy: 0.9032
Epoch 87/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2748 - accuracy: 0.8824 - val_loss: 0.3098 - val_accuracy: 0.8710
Epoch 88/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2704 - accuracy: 0.8860 - val_loss: 0.3121 - val_accuracy: 0.8387
Epoch 89/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2697 - accuracy: 0.8860 - val_loss: 0.3158 - val_accuracy: 0.8387
Epoch 90/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2668 - accuracy: 0.8897 - val_loss: 0.3152 - val_accuracy: 0.8387
Epoch 91/100
3/3 [==============================] - 0s 14ms/step - loss: 0.2619 - accuracy: 0.8971 - val_loss: 0.3177 - val_accuracy: 0.8387
Epoch 92/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2609 - accuracy: 0.8897 - val_loss: 0.3207 - val_accuracy: 0.8387
Epoch 93/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2616 - accuracy: 0.8934 - val_loss: 0.3216 - val_accuracy: 0.8387
Epoch 94/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2604 - accuracy: 0.8897 - val_loss: 0.3249 - val_accuracy: 0.8387
Epoch 95/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2569 - accuracy: 0.8971 - val_loss: 0.3252 - val_accuracy: 0.8387
Epoch 96/100
3/3 [==============================] - 0s 13ms/step - loss: 0.2556 - accuracy: 0.9007 - val_loss: 0.3199 - val_accuracy: 0.8387
Epoch 97/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2542 - accuracy: 0.9081 - val_loss: 0.3121 - val_accuracy: 0.8387
Epoch 98/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2553 - accuracy: 0.9081 - val_loss: 0.3088 - val_accuracy: 0.8387
Epoch 99/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2509 - accuracy: 0.9081 - val_loss: 0.3095 - val_accuracy: 0.9032
Epoch 100/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2504 - accuracy: 0.8971 - val_loss: 0.3119 - val_accuracy: 0.9032
七、模型评估
import matplotlib.pyplot as plt
acc=history.history['accuracy']
val_acc=history.history['val_accuracy']
loss=history.history['loss']
val_loss=history.history['val_loss']
epochs_range=range(epochs)
plt.figure(figsize=(14,4))
plt.subplot(1,2,1)
plt.plot(epochs_range,acc,label='Training Accuracy')
plt.plot(epochs_range,val_acc,label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1,2,2)
plt.plot(epochs_range,loss,label='Training Loss')
plt.plot(epochs_range,val_loss,label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
运行结果:
scores=model.evaluate(x_test,y_test,verbose=0)
print("%s:%.2f%%" % (model.metrics_names[1],scores[1]*100))
运行结果:
accuracy:90.32%
八、心得体会
学习了什么是RNN,并且在TensorFlow环境下构建了简单的RNN模型,本次的训练结果恰巧达到90%,中间也通过修改学习率等尝试提升准确率,但结果都差强人意,留待以后学习过程中构建更为复杂的网络模型提升准确率。