当前位置: 首页 > article >正文

TensorFlow_T7 咖啡豆识别

目录

一、前言

二、前期准备

1、设置GPU

2、导入数据

3、查看数据图片

三、数据预处理 

1、加载数据

2、可视化数据

3、配置数据集

四、构建VGG-16网络

1、VGG优缺点分析

2、自建模型

3、网络结构图

五、编译

六、 训练模型


一、前言

  •   🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

二、前期准备

1、设置GPU

import tensorflow as tf
 
gpus=tf.config.list_physical_devices("GPU")
 
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0],True)
    tf.config.set_visible_devices([gpus[0]],"GPU")
gpus

2、导入数据

import pathlib
 
data_dir="D:\THE MNIST DATABASE\T7"
data_dir=pathlib.Path(data_dir)

3、查看数据图片

image_count=len(list(data_dir.glob('*/*.png')))
 
print("图片总数为:",image_count)

运行结果:

 

三、数据预处理 

1、加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset中。

train_ds=tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(224,224),
    batch_size=32
)

运行结果如下:

加载验证集:

val_ds=tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(224,224),
    batch_size=32
)

运行结果如下:

通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称。

class_names=train_ds.class_names
print(class_names)

 运行结果如下:

2、可视化数据

import matplotlib.pyplot as plt
 
plt.figure(figsize=(10,4))
 
for images,labels in train_ds.take(1):
    for i in range(10):
        ax=plt.subplot(2,5,i+1)
        
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        
        plt.axis("off")

运行结果如下:

查看图像格式:

for image_batch,labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

运行结果如下:

3、配置数据集

  • shuffle() :打乱数据;
  • prefetch() :预取数据,加速运行;
  • cache() :将数据集缓存到内存当中,加速运行;

对数据集进行预处理,对于验证集只进行了缓存和预取操作,没有进行打乱操作。

AUTOTUNE=tf.data.AUTOTUNE
 
train_ds=train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds=val_ds.cache().prefetch(buffer_size=AUTOTUNE)

对图像数据集进行归一化。 将输入数据除以255,将像素值缩放到0到1之间。然后,使用map函数将这个归一化层应用到训练数据集train_ds和验证数据集val_ds的每个样本上。这样,所有的图像都会被归一化,以便在神经网络中更好地处理。

from tensorflow.keras import layers
 
normalization_layer=layers.experimental.preprocessing.Rescaling(1./255)
 
train_ds=train_ds.map(lambda x,y:(normalization_layer(x),y))
val_ds=val_ds.map(lambda x,y:(normalization_layer(x),y))

从验证数据集中获取一个批次的图像和标签,然后将第一个图像存储在变量first_image中。接下来,使用numpy库的min和max函数分别计算first_image中的最小值和最大值,并将它们打印出来。这样可以帮助我们了解图像数据的归一化情况,例如是否所有像素值都在0到1之间。

import numpy as np
 
image_batch,labels_batch=next(iter(val_ds))
first_image=image_batch[0]
 
#查看归一化后的数据
print(np.min(first_image),np.max(first_image))

运行结果如下:

四、构建VGG-16网络

1、VGG优缺点分析

(1)优点:结构简洁,整个网络都使用了同样大小的卷积核尺寸(3x3)和最大池化尺寸(2x2);

(2)缺点

  • 训练时间过长,调参难度大;
  • 需要的存储容量大,不利于部署。例如存储VGG-16权重值文件的大小为500多MB,不利于安装到嵌入式系统中;

2、自建模型

from tensorflow.keras import layers,models,Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D,MaxPooling2D,Dense,Flatten,Dropout
 
def vgg16(nb_classes,input_shape):
    input_tensor=Input(shape=input_shape)
    #1st block
    x=Conv2D(64,(3,3),activation='relu',padding='same')(input_tensor)
    x=Conv2D(64,(3,3),activation='relu',padding='same')(x)
    x=MaxPooling2D((2,2),strides=(2,2))(x)
    #2nd block
    x=Conv2D(128,(3,3),activation='relu',padding='same')(x)
    x=Conv2D(128,(3,3),activation='relu',padding='same')(x)
    x=MaxPooling2D((2,2),strides=(2,2))(x)
    #3rd block
    x=Conv2D(256,(3,3),activation='relu',padding='same')(x)
    x=Conv2D(256,(3,3),activation='relu',padding='same')(x)
    x=Conv2D(256,(3,3),activation='relu',padding='same')(x)
    x=MaxPooling2D((2,2),strides=(2,2))(x)
    #4th block
    x=Conv2D(512,(3,3),activation='relu',padding='same')(x)
    x=Conv2D(512,(3,3),activation='relu',padding='same')(x)
    x=Conv2D(512,(3,3),activation='relu',padding='same')(x)
    x=MaxPooling2D((2,2),strides=(2,2))(x)
    #5th block
    x=Conv2D(512,(3,3),activation='relu',padding='same')(x)
    x=Conv2D(512,(3,3),activation='relu',padding='same')(x)
    x=Conv2D(512,(3,3),activation='relu',padding='same')(x)
    x=MaxPooling2D((2,2),strides=(2,2))(x)
    #full connection
    x=Flatten()(x)
    x=Dense(4096,activation='relu')(x)
    x=Dense(4096,activation='relu')(x)
    output_tensor=Dense(nb_classes,activation='softmax',name='predictions')(x)
    
    model=Model(input_tensor,output_tensor)
    return model
 
model=vgg16(len(class_names),(224,224,3))
model.summary()

运行结果如下:

   

   

3、网络结构图

结构说明:

  • 13个卷积层(Convolutional Layer),分别用blockX_convX表示;
  • 3个全连接层(Fully connected Layer),分别用fcXpredictions表示;
  • 5个池化层(Pool layer),分别用blockX_pool表示;

VGG-16包含了16个隐藏层(13个卷积层和3个全连接层),故称为VGG-16;

五、编译

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率;
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新;
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率;
#设置初始学习率
initial_learning_rate=1e-4
 
lr_schedule=tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=30,
    decay_rate=0.92,
    staircase=True
)
 
#设置优化器
opt=tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)
 
model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

六、 训练模型

epochs=20
 
history=model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)

运行结果如图:

七、可视化结果

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
 
loss = history.history['loss']
val_loss = history.history['val_loss']
 
epochs_range = range(epochs)
 
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
 
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

运行结果如图:


👏觉得文章对自己有用的宝子可以收藏文章并给小编点个赞!

👏想了解更多统计学、数据分析、数据开发、机器学习算法、深度学习等有关知识的宝子们,可以关注小编,希望以后我们一起成长!


http://www.kler.cn/a/394380.html

相关文章:

  • 深入探索:Scrapy深度爬取策略与实践
  • @ComponentScan:Spring Boot中的自动装配大师
  • 前端垂直居中的多种实现方式及应用分析
  • Android Framework AMS(16)进程管理
  • 时序数据库TimescaleDB安装部署以及常见使用
  • 设计模式:工厂方法模式和策略模式
  • JavaEE-多线程初阶(5)
  • 自定义反序列化过程
  • 【金猿人物展】罗格科技CTO崔鹏——数据驱动未来:从2024看2025大数据行业的变革与挑战...
  • shell 100例
  • STM32中断系统
  • 库存管理高效秘籍
  • ubuntu的dns设置问题
  • 从ROS Bag文件提取点云数据并保存为PCD格式进行处理 ros ubuntu
  • 15分钟学 Go 第 52 天 :发布与版本控制
  • 如何将Edge标签页设置得干净好用
  • Docker部署Nginx
  • 【C语言】计算3x3矩阵每行的最大值并存入第四列
  • 解密复杂系统:理论、模型与案例(3)
  • Fantasy中玩家断线的检测
  • C语言的内存函数
  • 【LeetCode】【算法】538. 把二叉搜索树转换为累加树
  • 【IC每日一题:IC常用模块--RR/handshake/gray2bin】
  • SSH是 struts+spring+hibernate集成框架
  • 政务数据治理专栏开搞!
  • 浏览器是加载ES6模块的?