当前位置: 首页 > article >正文

TensorFlow_T11 优化器对比实验

目录

一、前言

二、前期准备

1、设置GPU

2、导入数据

三、数据预处理

1、加载数据

2、再次检查数据

 3、配置数据集

4、数据可视化

四、构建网络

五、训练模型

六、模型评估

1、Loss与Accuracy图

 ​编辑  2、模型评估


一、前言

  •   🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

二、前期准备

1、设置GPU

import tensorflow as tf
 
gpus = tf.config.list_physical_devices("GPU")
 
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")
 
# 打印显卡信息,确认GPU可用
print(gpus)

2、导入数据

import matplotlib.pyplot as plt
import warnings,pathlib
 
warnings.filterwarnings("ignore")             #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False    # 用来正常显示负号
 
data_dir    = "../data-2/data"
data_dir    = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)

运行结果如下:

三、数据预处理

1、加载数据

设置基本的图片格式

batch_size = 16
img_height = 336
img_width  = 336

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset中 

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
 
 
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

运行结果如下:

可以通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称

class_names = train_ds.class_names
print(class_names)

2、再次检查数据

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

运行结果如下: 

  • Image_batch是形状的张量(16, 336, 336, 3)。这是一批形状336x336x3的8张图片(最后一维指的是彩色通道RGB);
  • Label_batch是形状(16,)的张量,这些标签对应16张图片;

 3、配置数据集

  • shuffle() : 打乱数据

  • prefetch() :预取数据,加速运行

  • cache() :将数据集缓存到内存当中,加速运行
AUTOTUNE = tf.data.AUTOTUNE
 
def train_preprocessing(image,label):
    return (image/255.0,label)
 
train_ds = (
    train_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size
    .prefetch(buffer_size=AUTOTUNE)
)
 
val_ds = (
    val_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size
    .prefetch(buffer_size=AUTOTUNE)
)

4、数据可视化

plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")
 
for images, labels in train_ds.take(1):
    for i in range(15):
        plt.subplot(4, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
 
        # 显示图片
        plt.imshow(images[i])
        # 显示标签
        plt.xlabel(class_names[labels[i]-1])
 
plt.show()

运行结果如下

   

四、构建网络

from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Model
 
def create_model(optimizer='adam'):
    # 加载预训练模型
    vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',
                                                                include_top=False,
                                                                input_shape=(img_width, img_height, 3),
                                                                pooling='avg')
    for layer in vgg16_base_model.layers:
        layer.trainable = False
 
    X = vgg16_base_model.output
    
    X = Dense(170, activation='relu')(X)
    X = BatchNormalization()(X)
    X = Dropout(0.5)(X)
 
    output = Dense(len(class_names), activation='softmax')(X)
    vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)
 
    vgg16_model.compile(optimizer=optimizer,
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])
    return vgg16_model
 
model1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())
model2.summary()

运行结果如下

   这里使用了预训练的VGG16模型,并加载了在ImageNet数据集上的预训练权重,移除了分类部分,在后续将VGG16作为新模型的输入,并添加了对应的层及泛化措施,并实现了模型的编译过程。

五、训练模型

NO_EPOCHS = 50
 
history_model1  = model1.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
history_model2  = model2.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)

运行结果如下

   

六、模型评估

1、Loss与Accuracy图

from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300 #图片像素
plt.rcParams['figure.dpi']  = 300 #分辨率
 
acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']
 
loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']
 
epochs_range = range(len(acc1))
 
plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)
 
plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))
 
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
   
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))
 
plt.show()

可视化结果如下

   2、模型评估

def test_accuracy_report(model):
    score = model.evaluate(val_ds, verbose=0)
    print('Loss function: %s, accuracy:' % score[0], score[1])
    
test_accuracy_report(model2)

运行结果如下

  


👏觉得文章对自己有用的宝子可以收藏文章并给小编点个赞!

👏想了解更多统计学、数据分析、数据开发、机器学习算法、深度学习等有关知识的宝子们,可以关注小编,希望以后我们一起成长!


http://www.kler.cn/a/446507.html

相关文章:

  • 【Prompt Engineering】6 文本扩展
  • W25Q128读写实验(一)
  • App自动化之dom结构和元素定位方式(包含滑动列表定位)
  • 使用k6进行MongoDB负载测试
  • es使用knn向量检索中numCandidates和k应该如何配比更合适
  • C++版实用时间戳类(Timestamp)
  • 用docker快速安装电子白板Excalidraw绘制流程图
  • GaussDB数据库中SQL诊断解析之配置SQL限流
  • Bcrypt在线密码加密生成器
  • 【人工智能】用Python实现图卷积网络(GCN):从理论到节点分类实战
  • 【网络云计算】2024第51周-每日【2024/12/20】小测-理论-周测
  • WeakAuras NES Script(lua)
  • 【微信小程序开发 - 3】:项目组成介绍
  • 易快报-飞书-金蝶云星空集成项目技术分享
  • QoS 流分类
  • 云手机有哪些用途?云手机选择推荐
  • leetcode 面试经典 150 题:长度最小的子数组
  • HCIA/HCIP/HCIE的报名官网
  • Python中bs4库的详细介绍
  • 多个Echart遍历生成 / 词图云
  • 相机内外参知识
  • 【Rust自学】4.4. 引用与借用
  • Golang学习历程【第三篇 基本数据类型类型转换】
  • idea部署maven项目步骤(图+文)
  • 深入理解 JVM 垃圾回收机制
  • Neo4j【环境部署 02】图形数据库Neo4j在Linux系统ARM架构下的安装使用