当前位置: 首页 > article >正文

深度学习笔记10-数据增强(Tensorflow)

  • 🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

前言

在深度学习中,数据增强(Data Augmentation)是一种通过对现有数据进行各种转换和变换,从而生成更多训练样本的方法。在计算机视觉中,常见的数据增强方法包括随机裁剪、旋转、翻转、缩放、平移、亮度调整、对比度调整、添加噪声等。其主要目的是通过增加数据量和多样性,帮助模型学习到更加泛化的特征,提高模型的鲁棒性,并减少过拟合现象。

一、前期工作

1.加载数据

import matplotlib.pyplot as plt
import numpy as np
#隐藏警告
import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from tensorflow.keras import layers
data_dir   = "./T10/"
img_height = 224
img_width  = 224
batch_size = 32

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.3,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.3,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

84a9b85636ca420c9039ed8424ba43e9.png 

class_names = train_ds.class_names
print(class_names)

 2.创建测试集

由于原始数据集不包含测试集,因此需要创建一个。使用tf.data.experimental.cardinality确定验证集中有多少批次的数据,然后将其中的20%移至测试集。

val_batches = tf.data.experimental.cardinality(val_ds)
test_ds     = val_ds.take(val_batches // 5)
val_ds      = val_ds.skip(val_batches // 5)

print('Number of validation batches: %d' % tf.data.experimental.cardinality(val_ds))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_ds))

809130de99d247b69028d11796a0918e.png

3.配置数据集

AUTOTUNE=tf.data.AUTOTUNE
def preprocess_image(image,label):
    return(image/255.0,label)
#归一化处理
train_ds=train_ds.map(preprocess_image,num_parallel_calls=AUTOTUNE)
val_ds=val_ds.map(preprocess_image,num_parallel_calls=AUTOTUNE)
test_ds=test_ds.map(preprocess_image,num_parallel_calls=AUTOTUNE)

train_ds=train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds=train_ds.cache().prefetch(buffer_size=AUTOTUNE)

4.可视化 

plt.figure(figsize=(15,10))
for image,labels in train_ds.take(1):
    for i in range(8):
        ax=plt.subplot(5,8,i+1)
        plt.imshow(image[i])
        plt.title(class_name[labels[i]])

        plt.axis('off')

c9c79377867d4f52b24fff8a749f5899.png二、数据增强

我们可以使用 tf.keras.layers.RandomFliptf.keras.layers.RandomRotation 进行数据增强

  • tf.keras.layers.RandomFlip:水平和垂直随机翻转每个图像。
  • tf.keras.layers.RandomRotation:随机旋转每个图像
data_augmentation = tf.keras.Sequential([
  tf.keras.layers.RandomFlip("horizontal_and_vertical"),#添加一个随机翻转层,该层以一定的概率对输入图像进行水平和垂直翻转
  tf.keras.layers.RandomRotation(0.2),#添加一个随机旋转层,该层以一定的概率对输入图像进行旋转,旋转角度在-0.2到0.2弧度之间
])
# Add the image to a batch.
image = tf.expand_dims(images[i], 0) #0表示在数组的最前面增加一个维度,这样原本的单个图像就变成了一个批次。
plt.figure(figsize=(8, 8))
for i in range(9):
    augmented_image = data_augmentation(image)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image[0])
    plt.axis("off")

cdffb05a1552409881f6ad01bbdf020c.png

三、增强方式

1.方式一:嵌入model 

注意:只有在模型训练时(Model.fit)才会进行增强,在模型评估(Model.evaluate)以及预测(Model.predict)时并不会进行增强操作。

model = tf.keras.Sequential([
  data_augmentation,
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
])

2.方式二:在dataset中进行 

batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE

def prepare(ds):
    ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)
    return ds

四、训练模型

model = tf.keras.Sequential([
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(len(class_names))
])
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
epochs=20
history = model.fit(train_ds , validation_data=val_ds , epochs=epochs)

69d0be5fe5d342ccaa6bdbcccf490e51.png

acc,loss=model.evaluate(test_ds)
print("Accuracy", acc)

9e9845428e5f4bc8b846772f2119d989.png

五、自定义增强函数

随机亮度、对比度、色度、饱和度的设置

import random
# 这是大家可以自由发挥的一个地方
def aug_img(image):
    seed = (random.randint(0, 9), 0)
    # 随机改变图像对比度
    stateless_random_contrast = tf.image.stateless_random_contrast(image, lower=0.1, upper=1.0, seed=seed)
    # 随机改变图像的亮度
    stateless_random_brightness = tf.image.stateless_random_brightness(stateless_random_contrast, max_delta=0.3,seed=seed)
    # 随机改变图像的色度
    stateless_random_hue = tf.image.stateless_random_hue(stateless_random_brightness, max_delta=0.3,seed=seed)
    # 随机改变图像的饱和度
    stateless_random_saturation = tf.image.stateless_random_saturation(stateless_random_hue, lower=0.1, upper=1.0, seed=seed)
    
    return stateless_random_saturation
image = tf.expand_dims(images[7]*255, 0)
print("Min and max pixel values:", image.numpy().min(), image.numpy().max())
plt.figure(figsize=(8, 8))
for i in range(9):
    augmented_image = aug_img(image)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image[0].numpy().astype("uint8"))

    plt.axis("off")

e24f1a77fb954e31a0511df014218e38.png

六、总结

1.基础数据增强方式

  • 几何数据增强:包括旋转、平移、错切等操作,这些技术通过改变图像中像素值的位置来增强数据。
  • 非几何数据增强:侧重于图像的视觉外观,如噪声注入、翻转、裁剪、调整大小和色彩空间操作。
  • 翻转:水平或垂直翻转图像,是一种常用的数据增强技术。
  • 裁剪和调整大小:通过随机裁剪或中心裁剪作为数据增强,减小图像大小后再调整回原始大小。
  • 注入噪声:向图像中注入噪声,帮助模型学习稳健的特征。
  • 光度增强:通过改变RGB通道值来控制亮度,避免模型偏向特定光照条件。
  • 扰动:随机改变图像的亮度、对比度、饱和度和色调。
  • 核过滤:使用核或高斯模糊过滤器来锐化或模糊图像。

2.tf.data.experimental.cardinality

tf.data.experimental.cardinality是 TensorFlow 的一个函数,用于估计一个tf.data.Dataset数据集的元素数量。这个函数返回一个整数或None。如果返回整数,它代表数据集中元素的估计数量;如果返回None,则表示数据集的元素数量未知或无法确定。ad5cc8b5566e499390b8f5b8ebaac012.png

3.翻转和旋转

  • tf.keras.layers.RandomFlip:水平和垂直随机翻转每个图像。
  • tf.keras.layers.RandomRotation:随机旋转每个图像

4.有状态随机变换和无状态随机变换

290057707350442aa4bcf81d8e2bf29b.png

 


http://www.kler.cn/a/466236.html

相关文章:

  • C语言 递归编程练习
  • Eplan 项目结构(高层代号、安装地点、位置代号)
  • OpenNJet v3.2.0正式发布!
  • Dubbo扩展点加载机制
  • 等保测评和密评的相关性和区别
  • dockerignore文件怎么写
  • 在Vue3项目中使用svg-sprite-loader
  • Gitee 的基本用法
  • 查看打开的端口
  • 【JavaWeb后端学习笔记】MySQL的数据控制语言(Data Control Language,DCL)
  • 多线程访问FFmpegFrameGrabber.start方法阻塞问题
  • SkyWalking概述
  • 谷歌浏览器的高级安全设置使用方法
  • 整数拼接(哈希表 枚举)
  • docker基本概念,docker镜像管理,docker命令
  • zookeeper+kafka
  • 深入剖析MySQL数据库架构:核心组件、存储引擎与优化策略(四)
  • matlab系列专栏-matlab概述
  • xdoj 出现次数最多的数
  • WPF 数据绑定中的通知机制及其性能考虑
  • Android多渠道打包【友盟方式详细讲解版】
  • 《Opencv》基础操作详解(4)
  • python实现,outlook每接收一封邮件运行检查逻辑,然后发送一封邮件给指定邮箱
  • 单片机按键扫描程序,可以单击、双击、长按,使用状态机,无延时,不阻塞。
  • JavaScript中的“==”和“===”有什么区别
  • Docker 容器技术与 K8s