当前位置: 首页 > article >正文

第32周:猴痘病识别(Tensorflow实战第四周)

目录

前言

一、前期工作

1.1 设置GPU

1.2 导入数据

1.3 查看数据

二、数据预处理

2.1 加载数据

2.2 可视化数据

2.3 再次检查数据

2.4 配置数据集

2.4.1 基本概念介绍

2.4.2.代码完成

三、构建CNN网络

四、编译

五、训练模型

六、模型评估

6.1 Loss和Accuracy图

6.2 指定图片进行预测

总结


前言

  • 🍨 本文为[🔗365天深度学习训练营]中的学习记录博客
  • 🍖 原作者:[K同学啊]

说在前面

1)本周任务:基于CNN模型完成对猴痘病图片的识别

2)运行环境:Python3.6、Pycharm2020、tensorflow2.4.0


一、前期工作

1.1 设置GPU

代码如下:

# 一、前期准备
# 1.1 设置GPU
from tensorflow import keras
from tensorflow.keras import layers, models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    gpu0 = gpus[0]  # 如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True)  # 设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0], "GPU")
print(gpus)

1.2 导入数据

代码如下:

# 1.2 导入数据
data_dir = "./4-data/"
data_dir = pathlib.Path(data_dir)

1.3 查看数据

代码如下:

# 1.3 查看数据
image_count = len(list(data_dir.glob('*/*.jpg')))
print("图片总数为:", image_count)
Monkeypox = list(data_dir.glob('Monkeypox/*.jpg'))
PIL.Image.open(str(Monkeypox[0]))

输出:

图片总数为: 2142

二、数据预处理

2.1 加载数据

使用image_dataset_from_directory方法(详细可参考文章tf.keras.preprocessing.image_dataset_from_directory() 简介_tf.python.keras preprocessing在哪里-CSDN博客)将磁盘中的数据加载到tf.data.Dataset中。

测试集与验证集的关系:

  • 验证集并没有参与训练过程梯度下降过程的,狭义上来讲是没有参与模型的参数训练更新的。
  • 但是广义上来讲,验证集存在的意义确实参与了一个“人工调参”的过程,我们根据每一个epoch训练之后模型在valid data上的表现来决定是否需要训练进行early stop,或者根据这个过程模型的性能变化来调整模型的超参数,如学习率,batch_size等等。因此,我们也可以认为,验证集也参与了训练,但是并没有使得模型去overfit验证集

代码如下:

# 二、数据预处理
# 2.1 加载数据
batch_size = 32
img_height = 180
img_width = 180
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
class_names = train_ds.class_names
print(class_names)

输出如下:

Found 2142 files belonging to 2 classes.
Using 1714 files for training.

​Found 2142 files belonging to 2 classes.
Using 428 files for validation.

['Monkeypox', 'Others']

2.2 可视化数据

代码如下:

# 2.2 可视化数据
plt.figure(figsize=(20, 10))
for images, labels in train_ds.take(1):
    for i in range(20):
        ax = plt.subplot(5, 10, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")

输出:

2.3 再次检查数据

代码如下:

# 2.3再次检查数据
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

输出:

(32, 224, 224, 3)
(32,)

情况说明:

  • Image_batch是形状的张量(32,180,180,3)。这是一批形状180x180x3的32张图片(最后一维指的是彩色通道RGB)。
  • Label_batch是形状(32,)的张量,这些标签对应32张图片

2.4 配置数据集

2.4.1 基本概念介绍

prefetch():CPU 正在准备数据时,加速器处于空闲状态。相反,当加速器正在训练模型时,CPU 处于空闲状态。因此,训练所用的时间是 CPU 预处理时间和加速器训练时间的总和。prefetch()将训练步骤的预处理和模型执行过程重叠到一起。当加速器正在执行第 N 个训练步时,CPU 正在准备第 N+1 步的数据。这样做不仅可以最大限度地缩短训练的单步用时(而不是总用时),而且可以缩短提取和转换数据所需的时间。如果不使用prefetch(),CPU 和 GPU/TPU 在大部分时间都处于空闲状态

然后使用prefetch()可显著减少空闲时间:

2.4.2.代码完成

代码如下:

# cache():将数据集缓存到内存当中,加速运行
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

三、构建CNN网络

         卷积神经网络(CNN)的输入是张量 (Tensor) 形式的 (image_height, image_width, color_channels),包含了图像高度、宽度及颜色信息。不需要输入batch size。color_channels 为 (R,G,B) 分别对应 RGB 的三个颜色通道(color channel)。在此示例中,我们的 CNN 输入形状是 (180, 180, 3)。我们需要在声明第一层时将形状赋值给参数input_shape

网络结构图如下:

代码如下:

# 三、构建CNN网络
num_classes = 4
model = models.Sequential([
    layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=(img_height, img_width, 3)),

    layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)),  # 卷积层1,卷积核3*3
    layers.AveragePooling2D((2, 2)),  # 池化层1,2*2采样
    layers.Conv2D(32, (3, 3), activation='relu'),  # 卷积层2,卷积核3*3
    layers.AveragePooling2D((2, 2)),  # 池化层2,2*2采样
    layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3
    layers.Dropout(0.3),  # 让神经元以一定的概率停止工作,防止过拟合,提高模型的泛化能力。

    layers.Flatten(),  # Flatten层,连接卷积层与全连接层
    layers.Dense(128, activation='relu'),  # 全连接层,特征进一步提取
    layers.Dense(num_classes)  # 输出层,输出预期结果
])
model.summary()  # 打印网络结构

模型结构打印如下:

四、编译

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率

代码如下:

# 四、编译
# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

五、训练模型

代码如下:

# 五、训练模型
epochs = 10
history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)

训练过程打印如下:

六、模型评估

6.1 Loss和Accuracy图

代码如下:

# 六、模型评估
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

训练结果可视化如下:

6.2 指定图片进行预测

代码如下:

# 6.2 指定图片进行预测
# 加载效果最好的模型权重
model.load_weights('best_model.h5')
from PIL import Image
import numpy as np
# img = Image.open("./4-data/Monkeypox/M06_01_04.jpg")  #这里选择你需要预测的图片
img = Image.open("./4-data/Others/NM15_02_11.jpg")  #这里选择你需要预测的图片
image = tf.image.resize(img, [img_height, img_width])
img_array = tf.expand_dims(image, 0)
predictions = model.predict(img_array) # 这里选用你已经训练好的模型
print("预测结果为:",class_names[np.argmax(predictions)])

输出:

预测结果为: Others


总结

本周在上周的基础上增加了对指定图片进行预测的任务,并实现了正确预测,也让我更加熟练CNN模型搭建的流程


http://www.kler.cn/a/414382.html

相关文章:

  • Lerna管理和发布同一源码仓库的多个js/ts包
  • SNMPv3 项目实例
  • 如何借助AI生成PPT,让创作轻松又高效
  • Mybatis-基础操作
  • [Unity] 【游戏开发】角色设计3-如何为角色实现响应输入的控制器
  • PMP每日一练(三十八)
  • GitLab历史演进
  • 组成无重复数字的三位数
  • 输入一行字符,分别统计出其中英文字母、空格、数字和其它字符的个数。-多语言
  • 第02章 使用VMware部署CENTOS系统
  • SqlServer强制转换函数TRY_CONVERT和TRY_CAST
  • “小bug”示例
  • 一款现代化的轻量级跨平台Redis桌面客户端
  • 大数据机器学习算法与计算机视觉应用05:乘法权重算法
  • 【第十二课】Rust并发编程(三)
  • NodeFormer:一种用于节点分类的可扩展图结构学习 Transformer
  • 修改element UI el-table背景颜色样式 input select date vuetree
  • 如何在 IIS 上部署 .NET Core 应用程序 ?
  • 基于 Flask 和 Socket.IO 的 WebSocket 实时数据更新实现
  • 常用Python集成开发环境(IDE)
  • 基于FPGA的SD NAND读写测试(图文并茂+源代码+详细注释)
  • ISIS SSN/SRM 标志在 P2P 链路和 Broadcast 链路中的作用
  • Python全局解释器锁(GIL)深度解析
  • 现代化水库可视化管理平台:提升水库运行效率与安全保障
  • docker的joinsunsoft/docker.ui修改密码【未解决】
  • 二十六:Web条件请求的作用