
Week R3: RNN Heart-Disease Prediction (TensorFlow Version)

>- **🍨 This post is a learning-log entry for the [🔗365天深度学习训练营]**
>- **🍖 Original author: [K同学啊]**

🍺 Requirements:

  1. Find and fix the issues in the Week 8 program (answers are given in this post).
  2. Understand how a recurrent neural network (RNN) is built.
  3. Reach 87% accuracy on the test set.

🍻 Stretch goal (optional):

  1. Reach 89% accuracy on the test set

Previous posts in this series: Deep Learning Summary

🚀 My environment:

  • Language: Python 3.11.7
  • Editor: Jupyter Notebook
  • Deep learning framework: TensorFlow 2.13.0

The code flowchart is shown below:

I. What Is an RNN

A traditional neural network has a simple structure: input layer – hidden layer – output layer.

The biggest difference between an RNN and a traditional network is that at each step the RNN feeds the previous step's output back into the hidden layer, so past and present inputs are trained together, as shown below:

A concrete example shows how an RNN works. Suppose a user says "what time is it?". The network first splits the sentence into five basic units (four words plus a question mark).

These five units are then fed into the RNN in order: first "what" goes in, producing output 01.

Next, "time" is fed in, producing output 02.

Notice that when "time" is the input, the earlier word "what" also influences output 02 (half of the hidden state in the figure is black).

By the same logic, all earlier inputs affect every later output (the circle contains all of the preceding colors).

When the network classifies the intent of the sentence, only the final output, 05, is needed, as shown below:
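The step-by-step behavior described above can be sketched in a few lines of NumPy. This is an illustrative toy with random, untrained weights (the dimensions and variable names here are made up for the example): at each step, the new hidden state mixes the current input with the previous hidden state, which is exactly how earlier words come to influence later outputs.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy dimensions: 5 tokens ("what", "time", "is", "it", "?"), each a 4-dim vector
seq_len, input_dim, hidden_dim = 5, 4, 8

W_x = rng.normal(size=(hidden_dim, input_dim))   # input -> hidden weights
W_h = rng.normal(size=(hidden_dim, hidden_dim))  # previous hidden -> hidden weights
b   = np.zeros(hidden_dim)

x_seq = rng.normal(size=(seq_len, input_dim))    # the 5 basic units, in order

h = np.zeros(hidden_dim)                         # initial hidden state
for x_t in x_seq:
    # h carries information from all earlier steps into this one
    h = np.tanh(W_x @ x_t + W_h @ h + b)

# Only the final hidden state (the "05" above) is used for intent classification
print(h.shape)
```

For intent classification, only this last `h` would be passed to a dense output layer; that is what `return_sequences=False` (the default) does in Keras's SimpleRNN.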

II. Preparations

1. Set up the GPU

import tensorflow as tf

gpus=tf.config.list_physical_devices("GPU")

if gpus:
    gpu0=gpus[0]  # if there are multiple GPUs, use only GPU 0
    tf.config.experimental.set_memory_growth(gpu0,True)  # allocate GPU memory on demand
    tf.config.set_visible_devices([gpu0],"GPU")
    
gpus

Output:

[]

My TensorFlow build has no GPU support, so the list is empty.

2. Load the data

Dataset description:

  • age: age
  • sex: sex
  • cp: chest pain type (4 values)
  • trestbps: resting blood pressure
  • chol: serum cholesterol (mg/dl)
  • fbs: fasting blood sugar > 120 mg/dl
  • restecg: resting electrocardiographic results (values 0, 1, 2)
  • thalach: maximum heart rate achieved
  • exang: exercise-induced angina
  • oldpeak: ST depression induced by exercise relative to rest
  • slope: slope of the peak exercise ST segment
  • ca: number of major vessels (0–3) colored by fluoroscopy
  • thal: 0 = normal; 1 = fixed defect; 2 = reversible defect
  • target: 0 = lower chance of heart attack; 1 = higher chance of heart attack

import pandas as pd
import numpy as np

df=pd.read_csv(r"D:\THE MNIST DATABASE\RNN\R3\heart.csv")  # raw string so backslashes in the Windows path are not treated as escapes
df

Output:

3. Inspect the data

# check for missing values
df.isnull().sum()

Output:

age         0
sex         0
cp          0
trestbps    0
chol        0
fbs         0
restecg     0
thalach     0
exang       0
oldpeak     0
slope       0
ca          0
thal        0
target      0
dtype: int64
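Every column reports zero missing values here, so no cleaning is needed. For reference, if a column did contain nulls, two common fixes are dropping incomplete rows or imputing with a column statistic. A small sketch on a made-up frame (not the heart.csv data):

```python
import pandas as pd
import numpy as np

# A tiny made-up frame, just to show the pattern
demo = pd.DataFrame({"age": [63, np.nan, 41], "chol": [233, 250, np.nan]})

print(demo.isnull().sum())  # per-column missing counts, same check as above

# Option 1: drop any row with a missing value
dropped = demo.dropna()

# Option 2: fill each hole with that column's median
imputed = demo.fillna(demo.median(numeric_only=True))
```

Which option is better depends on how many rows would be lost; with only 303 samples in this dataset, imputation would usually be preferred.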

III. Data Preprocessing

1. Split into training and test sets

On the relationship between the test set and the validation set:

(1) The validation set takes no part in gradient descent; strictly speaking, it never contributes to updating the model's parameters.

(2) In a broader sense, however, the validation set does participate in a "manual tuning" loop: based on the model's performance on the validation data after each epoch, we decide whether to stop training early, or adjust hyperparameters such as the learning rate and batch size.

(3) So one could say the validation set does take part in training, but without letting the model overfit to it.
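The "manual tuning" loop described above can be partly automated with Keras callbacks. A hedged configuration sketch (not part of the original code; the `patience` value, the monitored metrics, and the checkpoint filename are arbitrary choices for illustration):

```python
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

callbacks = [
    # Stop if val_loss has not improved for 10 epochs, restoring the best weights
    EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True),
    # Keep only the best checkpoint (by val_accuracy) seen so far on disk
    ModelCheckpoint("best_model.keras", monitor="val_accuracy", save_best_only=True),
]

# Later, passed to training:
# model.fit(..., validation_data=(x_test, y_test), callbacks=callbacks)
```

Note that in this post the test set doubles as the validation set via `validation_data=(x_test, y_test)`; for a stricter protocol one would carve a separate validation split out of the training data.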

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

x=df.iloc[:,:-1]
y=df.iloc[:,-1]

x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.1,random_state=1)
x_train.shape,y_train.shape

Output:

((272, 13), (272,))

2. Standardization

# standardize each feature column to a standard normal distribution; note that scaling is applied per column
sc=StandardScaler()
x_train=sc.fit_transform(x_train)
x_test=sc.transform(x_test)

x_train=x_train.reshape(x_train.shape[0],x_train.shape[1],1)
x_test=x_test.reshape(x_test.shape[0],x_test.shape[1],1)
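The reshape above turns each sample's 13 standardized features into a length-13 sequence with one value per time step, i.e. the (samples, timesteps, features) layout that SimpleRNN expects. A quick check of what the reshape does (with fabricated data of the same shape as the training split):

```python
import numpy as np

# Stand-in for the scaled training data: 272 samples x 13 features
fake = np.arange(272 * 13, dtype=float).reshape(272, 13)

seq = fake.reshape(fake.shape[0], fake.shape[1], 1)

print(seq.shape)                           # (272, 13, 1)
print(np.array_equal(seq[:, :, 0], fake))  # values are untouched; only the shape changes
```

Treating tabular features as a "sequence" is a modeling choice specific to this exercise, not something inherent to the data.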

IV. Build the RNN Model

Function prototype:

tf.keras.layers.SimpleRNN(units, activation='tanh', use_bias=True, kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal', bias_initializer='zeros', kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, recurrent_constraint=None,
bias_constraint=None, dropout=0.0, recurrent_dropout=0.0, return_sequences=False, return_state=False,
go_backwards=False, stateful=False, unroll=False, **kwargs)

Key parameters:

  • units: positive integer, the dimensionality of the output space.
  • activation: activation function to use. Default: hyperbolic tangent (tanh). If None is passed, no activation is applied (i.e. linear activation: a(x) = x).
  • use_bias: boolean, whether the layer uses a bias vector.
  • kernel_initializer: initializer for the kernel weight matrix, used for the linear transformation of the inputs (see initializers).
  • recurrent_initializer: initializer for the recurrent_kernel weight matrix, used for the linear transformation of the recurrent state (see initializers).
  • bias_initializer: initializer for the bias vector (see initializers).
  • dropout: float between 0 and 1, the fraction of units to drop for the linear transformation of the inputs.

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,LSTM,SimpleRNN

model=Sequential()
model.add(SimpleRNN(200,input_shape=(13,1),activation='relu'))
model.add(Dense(100,activation='relu'))
model.add(Dense(1,activation='sigmoid'))
model.summary()

Output:

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 simple_rnn (SimpleRNN)      (None, 200)               40400     
                                                                 
 dense (Dense)               (None, 100)               20100     
                                                                 
 dense_1 (Dense)             (None, 1)                 101       
                                                                 
=================================================================
Total params: 60601 (236.72 KB)
Trainable params: 60601 (236.72 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
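The parameter counts in the summary can be verified by hand. A SimpleRNN layer has units × (units + input_dim + 1) weights (recurrent kernel + input kernel + bias), and a Dense layer has in_features × out_features + out_features:

```python
units, input_dim = 200, 1

simple_rnn = units * (units + input_dim + 1)  # 200 * (200 + 1 + 1)
dense_1    = 200 * 100 + 100                  # Dense(100) after the 200-unit RNN
dense_2    = 100 * 1 + 1                      # final Dense(1)

print(simple_rnn, dense_1, dense_2, simple_rnn + dense_1 + dense_2)
# → 40400 20100 101 60601, matching the summary above
```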

V. Compile the Model

opt=tf.keras.optimizers.Adam(learning_rate=1e-4)

model.compile(loss='binary_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])  # a list is the idiomatic form; a bare string also works
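binary_crossentropy pairs naturally with the final sigmoid unit: for labels y and predicted probabilities p, the loss is -mean(y·log p + (1-y)·log(1-p)). A small NumPy check with made-up probabilities (not actual model output):

```python
import numpy as np

y_true = np.array([1, 0, 1, 1])
y_prob = np.array([0.9, 0.2, 0.6, 0.8])  # hypothetical sigmoid outputs

# Binary cross-entropy averaged over the four samples
bce = -np.mean(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob))
print(round(bce, 4))  # → 0.2656
```

Confident, well-calibrated predictions (p near the true label) drive this loss toward zero, which is what the optimizer minimizes below.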

VI. Train the Model

epochs=100

history=model.fit(x_train,y_train,
                  epochs=epochs,
                  batch_size=128,
                  validation_data=(x_test,y_test),
                  verbose=1)

Output:

Epoch 1/100
3/3 [==============================] - 1s 140ms/step - loss: 0.6778 - accuracy: 0.7206 - val_loss: 0.6456 - val_accuracy: 0.8710
Epoch 2/100
3/3 [==============================] - 0s 16ms/step - loss: 0.6669 - accuracy: 0.7721 - val_loss: 0.6315 - val_accuracy: 0.8710
Epoch 3/100
3/3 [==============================] - 0s 16ms/step - loss: 0.6576 - accuracy: 0.7794 - val_loss: 0.6184 - val_accuracy: 0.8387
Epoch 4/100
3/3 [==============================] - 0s 16ms/step - loss: 0.6489 - accuracy: 0.7794 - val_loss: 0.6051 - val_accuracy: 0.8710
Epoch 5/100
3/3 [==============================] - 0s 16ms/step - loss: 0.6399 - accuracy: 0.7721 - val_loss: 0.5921 - val_accuracy: 0.8710
Epoch 6/100
3/3 [==============================] - 0s 16ms/step - loss: 0.6309 - accuracy: 0.7757 - val_loss: 0.5785 - val_accuracy: 0.8710
Epoch 7/100
3/3 [==============================] - 0s 16ms/step - loss: 0.6221 - accuracy: 0.7757 - val_loss: 0.5644 - val_accuracy: 0.8710
Epoch 8/100
3/3 [==============================] - 0s 16ms/step - loss: 0.6133 - accuracy: 0.7757 - val_loss: 0.5502 - val_accuracy: 0.8710
Epoch 9/100
3/3 [==============================] - 0s 15ms/step - loss: 0.6037 - accuracy: 0.7757 - val_loss: 0.5355 - val_accuracy: 0.8710
Epoch 10/100
3/3 [==============================] - 0s 17ms/step - loss: 0.5932 - accuracy: 0.7794 - val_loss: 0.5197 - val_accuracy: 0.8387
Epoch 11/100
3/3 [==============================] - 0s 16ms/step - loss: 0.5832 - accuracy: 0.7831 - val_loss: 0.5031 - val_accuracy: 0.8387
Epoch 12/100
3/3 [==============================] - 0s 17ms/step - loss: 0.5717 - accuracy: 0.7941 - val_loss: 0.4862 - val_accuracy: 0.8387
Epoch 13/100
3/3 [==============================] - 0s 16ms/step - loss: 0.5595 - accuracy: 0.7941 - val_loss: 0.4670 - val_accuracy: 0.8387
Epoch 14/100
3/3 [==============================] - 0s 14ms/step - loss: 0.5462 - accuracy: 0.8015 - val_loss: 0.4456 - val_accuracy: 0.8710
Epoch 15/100
3/3 [==============================] - 0s 16ms/step - loss: 0.5318 - accuracy: 0.8088 - val_loss: 0.4226 - val_accuracy: 0.8710
Epoch 16/100
3/3 [==============================] - 0s 16ms/step - loss: 0.5166 - accuracy: 0.8015 - val_loss: 0.3984 - val_accuracy: 0.8387
Epoch 17/100
3/3 [==============================] - 0s 16ms/step - loss: 0.5000 - accuracy: 0.8015 - val_loss: 0.3757 - val_accuracy: 0.8387
Epoch 18/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4845 - accuracy: 0.8088 - val_loss: 0.3550 - val_accuracy: 0.8387
Epoch 19/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4694 - accuracy: 0.8088 - val_loss: 0.3355 - val_accuracy: 0.8710
Epoch 20/100
3/3 [==============================] - 0s 17ms/step - loss: 0.4545 - accuracy: 0.8015 - val_loss: 0.3177 - val_accuracy: 0.8710
Epoch 21/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4425 - accuracy: 0.8015 - val_loss: 0.3035 - val_accuracy: 0.8710
Epoch 22/100
3/3 [==============================] - 0s 15ms/step - loss: 0.4350 - accuracy: 0.8015 - val_loss: 0.2928 - val_accuracy: 0.8710
Epoch 23/100
3/3 [==============================] - 0s 15ms/step - loss: 0.4264 - accuracy: 0.8015 - val_loss: 0.2856 - val_accuracy: 0.8710
Epoch 24/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4199 - accuracy: 0.7978 - val_loss: 0.2840 - val_accuracy: 0.9032
Epoch 25/100
3/3 [==============================] - 0s 17ms/step - loss: 0.4175 - accuracy: 0.8088 - val_loss: 0.2795 - val_accuracy: 0.9032
Epoch 26/100
3/3 [==============================] - 0s 15ms/step - loss: 0.4127 - accuracy: 0.8051 - val_loss: 0.2726 - val_accuracy: 0.9032
Epoch 27/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4080 - accuracy: 0.8088 - val_loss: 0.2675 - val_accuracy: 0.8710
Epoch 28/100
3/3 [==============================] - 0s 17ms/step - loss: 0.4088 - accuracy: 0.8125 - val_loss: 0.2663 - val_accuracy: 0.9032
Epoch 29/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4026 - accuracy: 0.8235 - val_loss: 0.2671 - val_accuracy: 0.9032
Epoch 30/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3955 - accuracy: 0.8125 - val_loss: 0.2701 - val_accuracy: 0.9032
Epoch 31/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3929 - accuracy: 0.8162 - val_loss: 0.2703 - val_accuracy: 0.9032
Epoch 32/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3905 - accuracy: 0.8199 - val_loss: 0.2686 - val_accuracy: 0.9032
Epoch 33/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3877 - accuracy: 0.8199 - val_loss: 0.2631 - val_accuracy: 0.9032
Epoch 34/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3832 - accuracy: 0.8235 - val_loss: 0.2568 - val_accuracy: 0.9032
Epoch 35/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3843 - accuracy: 0.8162 - val_loss: 0.2560 - val_accuracy: 0.9032
Epoch 36/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3809 - accuracy: 0.8199 - val_loss: 0.2577 - val_accuracy: 0.9032
Epoch 37/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3752 - accuracy: 0.8199 - val_loss: 0.2602 - val_accuracy: 0.9032
Epoch 38/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3718 - accuracy: 0.8309 - val_loss: 0.2629 - val_accuracy: 0.9032
Epoch 39/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3694 - accuracy: 0.8235 - val_loss: 0.2622 - val_accuracy: 0.9032
Epoch 40/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3666 - accuracy: 0.8272 - val_loss: 0.2601 - val_accuracy: 0.9032
Epoch 41/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3655 - accuracy: 0.8309 - val_loss: 0.2594 - val_accuracy: 0.9032
Epoch 42/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3643 - accuracy: 0.8346 - val_loss: 0.2587 - val_accuracy: 0.9032
Epoch 43/100
3/3 [==============================] - 0s 19ms/step - loss: 0.3600 - accuracy: 0.8382 - val_loss: 0.2610 - val_accuracy: 0.9032
Epoch 44/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3568 - accuracy: 0.8382 - val_loss: 0.2637 - val_accuracy: 0.9032
Epoch 45/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3560 - accuracy: 0.8346 - val_loss: 0.2608 - val_accuracy: 0.9032
Epoch 46/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3527 - accuracy: 0.8382 - val_loss: 0.2563 - val_accuracy: 0.9032
Epoch 47/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3506 - accuracy: 0.8382 - val_loss: 0.2541 - val_accuracy: 0.9032
Epoch 48/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3482 - accuracy: 0.8419 - val_loss: 0.2542 - val_accuracy: 0.9032
Epoch 49/100
3/3 [==============================] - 0s 14ms/step - loss: 0.3457 - accuracy: 0.8419 - val_loss: 0.2560 - val_accuracy: 0.9032
Epoch 50/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3418 - accuracy: 0.8456 - val_loss: 0.2558 - val_accuracy: 0.9032
Epoch 51/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3401 - accuracy: 0.8529 - val_loss: 0.2554 - val_accuracy: 0.9032
Epoch 52/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3381 - accuracy: 0.8529 - val_loss: 0.2577 - val_accuracy: 0.9032
Epoch 53/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3354 - accuracy: 0.8529 - val_loss: 0.2608 - val_accuracy: 0.9032
Epoch 54/100
3/3 [==============================] - 0s 13ms/step - loss: 0.3337 - accuracy: 0.8603 - val_loss: 0.2611 - val_accuracy: 0.9032
Epoch 55/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3318 - accuracy: 0.8603 - val_loss: 0.2628 - val_accuracy: 0.9032
Epoch 56/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3302 - accuracy: 0.8640 - val_loss: 0.2666 - val_accuracy: 0.9032
Epoch 57/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3292 - accuracy: 0.8603 - val_loss: 0.2669 - val_accuracy: 0.9032
Epoch 58/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3261 - accuracy: 0.8640 - val_loss: 0.2655 - val_accuracy: 0.9032
Epoch 59/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3231 - accuracy: 0.8640 - val_loss: 0.2669 - val_accuracy: 0.9032
Epoch 60/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3219 - accuracy: 0.8640 - val_loss: 0.2701 - val_accuracy: 0.9032
Epoch 61/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3207 - accuracy: 0.8676 - val_loss: 0.2714 - val_accuracy: 0.9032
Epoch 62/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3168 - accuracy: 0.8640 - val_loss: 0.2727 - val_accuracy: 0.9032
Epoch 63/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3150 - accuracy: 0.8640 - val_loss: 0.2709 - val_accuracy: 0.9032
Epoch 64/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3139 - accuracy: 0.8860 - val_loss: 0.2688 - val_accuracy: 0.8710
Epoch 65/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3130 - accuracy: 0.8934 - val_loss: 0.2700 - val_accuracy: 0.8710
Epoch 66/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3118 - accuracy: 0.8824 - val_loss: 0.2725 - val_accuracy: 0.8710
Epoch 67/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3077 - accuracy: 0.8897 - val_loss: 0.2765 - val_accuracy: 0.8710
Epoch 68/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3050 - accuracy: 0.8934 - val_loss: 0.2801 - val_accuracy: 0.9032
Epoch 69/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3022 - accuracy: 0.8897 - val_loss: 0.2821 - val_accuracy: 0.9032
Epoch 70/100
3/3 [==============================] - 0s 13ms/step - loss: 0.3002 - accuracy: 0.8897 - val_loss: 0.2837 - val_accuracy: 0.9032
Epoch 71/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2981 - accuracy: 0.8934 - val_loss: 0.2852 - val_accuracy: 0.8710
Epoch 72/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2973 - accuracy: 0.8860 - val_loss: 0.2870 - val_accuracy: 0.8710
Epoch 73/100
3/3 [==============================] - 0s 12ms/step - loss: 0.2962 - accuracy: 0.8860 - val_loss: 0.2873 - val_accuracy: 0.8710
Epoch 74/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2954 - accuracy: 0.8897 - val_loss: 0.2849 - val_accuracy: 0.8710
Epoch 75/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2909 - accuracy: 0.8860 - val_loss: 0.2865 - val_accuracy: 0.9032
Epoch 76/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2867 - accuracy: 0.8897 - val_loss: 0.2942 - val_accuracy: 0.9032
Epoch 77/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2888 - accuracy: 0.8824 - val_loss: 0.3043 - val_accuracy: 0.9032
Epoch 78/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2932 - accuracy: 0.8713 - val_loss: 0.3046 - val_accuracy: 0.9032
Epoch 79/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2871 - accuracy: 0.8824 - val_loss: 0.2997 - val_accuracy: 0.8710
Epoch 80/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2799 - accuracy: 0.8787 - val_loss: 0.2997 - val_accuracy: 0.8387
Epoch 81/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2790 - accuracy: 0.8860 - val_loss: 0.2980 - val_accuracy: 0.8387
Epoch 82/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2782 - accuracy: 0.8934 - val_loss: 0.2978 - val_accuracy: 0.8387
Epoch 83/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2758 - accuracy: 0.9007 - val_loss: 0.3000 - val_accuracy: 0.8710
Epoch 84/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2752 - accuracy: 0.8897 - val_loss: 0.3061 - val_accuracy: 0.9032
Epoch 85/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2794 - accuracy: 0.8787 - val_loss: 0.3143 - val_accuracy: 0.9032
Epoch 86/100
3/3 [==============================] - 0s 14ms/step - loss: 0.2809 - accuracy: 0.8750 - val_loss: 0.3126 - val_accuracy: 0.9032
Epoch 87/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2748 - accuracy: 0.8824 - val_loss: 0.3098 - val_accuracy: 0.8710
Epoch 88/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2704 - accuracy: 0.8860 - val_loss: 0.3121 - val_accuracy: 0.8387
Epoch 89/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2697 - accuracy: 0.8860 - val_loss: 0.3158 - val_accuracy: 0.8387
Epoch 90/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2668 - accuracy: 0.8897 - val_loss: 0.3152 - val_accuracy: 0.8387
Epoch 91/100
3/3 [==============================] - 0s 14ms/step - loss: 0.2619 - accuracy: 0.8971 - val_loss: 0.3177 - val_accuracy: 0.8387
Epoch 92/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2609 - accuracy: 0.8897 - val_loss: 0.3207 - val_accuracy: 0.8387
Epoch 93/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2616 - accuracy: 0.8934 - val_loss: 0.3216 - val_accuracy: 0.8387
Epoch 94/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2604 - accuracy: 0.8897 - val_loss: 0.3249 - val_accuracy: 0.8387
Epoch 95/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2569 - accuracy: 0.8971 - val_loss: 0.3252 - val_accuracy: 0.8387
Epoch 96/100
3/3 [==============================] - 0s 13ms/step - loss: 0.2556 - accuracy: 0.9007 - val_loss: 0.3199 - val_accuracy: 0.8387
Epoch 97/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2542 - accuracy: 0.9081 - val_loss: 0.3121 - val_accuracy: 0.8387
Epoch 98/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2553 - accuracy: 0.9081 - val_loss: 0.3088 - val_accuracy: 0.8387
Epoch 99/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2509 - accuracy: 0.9081 - val_loss: 0.3095 - val_accuracy: 0.9032
Epoch 100/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2504 - accuracy: 0.8971 - val_loss: 0.3119 - val_accuracy: 0.9032

VII. Evaluate the Model

import matplotlib.pyplot as plt

acc=history.history['accuracy']
val_acc=history.history['val_accuracy']

loss=history.history['loss']
val_loss=history.history['val_loss']

epochs_range=range(epochs)

plt.figure(figsize=(14,4))
plt.subplot(1,2,1)

plt.plot(epochs_range,acc,label='Training Accuracy')
plt.plot(epochs_range,val_acc,label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1,2,2)
plt.plot(epochs_range,loss,label='Training Loss')
plt.plot(epochs_range,val_loss,label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

Output:

scores=model.evaluate(x_test,y_test,verbose=0)
print("%s:%.2f%%" % (model.metrics_names[1],scores[1]*100))

Output:

accuracy:90.32%
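model.evaluate reports accuracy alone; for a medical task it is often worth also examining the confusion counts, especially recall on the positive class. A hedged sketch with made-up labels and predicted probabilities (thresholded at 0.5, the implicit cutoff behind the accuracy figure above):

```python
import numpy as np

y_true = np.array([1, 1, 0, 0, 1, 0, 1, 0])  # hypothetical test labels
y_prob = np.array([0.92, 0.41, 0.10, 0.55, 0.77, 0.30, 0.85, 0.05])

y_pred = (y_prob >= 0.5).astype(int)  # sigmoid output -> hard class label

tp = int(np.sum((y_pred == 1) & (y_true == 1)))  # true positives
tn = int(np.sum((y_pred == 0) & (y_true == 0)))  # true negatives
fp = int(np.sum((y_pred == 1) & (y_true == 0)))  # false positives
fn = int(np.sum((y_pred == 0) & (y_true == 1)))  # false negatives (missed cases)

accuracy = (tp + tn) / len(y_true)
recall   = tp / (tp + fn)  # of real heart-disease cases, how many were caught
print(tp, tn, fp, fn, accuracy, recall)
```

In practice one would feed `model.predict(x_test)` into the same thresholding; a missed heart-disease case (false negative) is usually costlier than a false alarm, which accuracy alone does not reveal.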

VIII. Takeaways

I learned what an RNN is and built a simple RNN model in TensorFlow. This run happened to reach 90% test accuracy. Along the way I tried adjusting the learning rate and other settings to push the accuracy higher, but the results fell short of expectations; I plan to revisit this with more complex network architectures as my studies progress.
