当前位置: 首页 > article >正文

TensorFlow2从磁盘读取图片数据集的示例(tf.data.Dataset.list_files)

import os
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications.resnet import ResNet50
from pathlib import Path
import numpy as np

#数据所在文件夹
base_dir = './data/cats_and_dogs'
train_dir = Path(os.path.join(base_dir,'train'))
file_pattern = os.path.join(train_dir,'*/*.jpg')
image_count = len(list(train_dir.glob('*/*.jpg')))

list_ds = tf.data.Dataset.list_files(file_pattern,shuffle = False)
list_ds = list_ds.shuffle(image_count, reshuffle_each_iteration=False)
for f in list_ds.take(5):
  print(f.numpy())
  
class_names = np.array(sorted([item.name for item in train_dir.glob('*') ]))
print(class_names)

val_size = int(image_count * 0.2)
train_data = list_ds.skip(val_size)
validation_data = list_ds.take(val_size)
print(tf.data.experimental.cardinality(train_data).numpy())
print(tf.data.experimental.cardinality(validation_data).numpy())


def get_label(file_path):
  parts = tf.strings.split(file_path, os.path.sep)
  one_hot = parts[-2] == class_names
  return tf.argmax(one_hot)

def decode_img(img):
  img = tf.io.decode_jpeg(img, channels=3)
  return tf.image.resize(img, [64, 64])

def process_path(file_path):
  label = get_label(file_path)
  img = tf.io.read_file(file_path)
  img = decode_img(img)
  return img, label

train_data = train_data.map(process_path, num_parallel_calls=tf.data.AUTOTUNE)
validation_data = validation_data.map(process_path, num_parallel_calls=tf.data.AUTOTUNE)

for image, label in train_data.take(2):
  print("Image shape: ", image.numpy().shape)
  print("Label: ", label.numpy())

def configure_for_performance(ds):
  ds = ds.cache()
  ds = ds.shuffle(buffer_size=1000)
  ds = ds.batch(4)
  ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE)
  return ds

train_data = configure_for_performance(train_data)
validation_data = configure_for_performance(validation_data)


save_model_cb = tf.keras.callbacks.ModelCheckpoint(filepath='model_resnet50_cats_and_dogs.h5', save_freq='epoch')

base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(64, 64, 3))
base_model.trainable = True
    
model = tf.keras.models.Sequential([
    base_model,
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation='relu',kernel_regularizer=tf.keras.regularizers.l2(l=0.01)),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

model.compile(loss='binary_crossentropy',optimizer = Adam(lr=1e-3),metrics = ['acc'])

history = model.fit(train_data.repeat(),steps_per_epoch=100,epochs=50,validation_data=validation_data.repeat(),validation_steps=50,verbose=1,callbacks = [save_model_cb])


http://www.kler.cn/a/106444.html

相关文章:

  • Redis实战案例(黑马点评)
  • 【Linux系统编程】第四十六弹---线程同步与生产消费模型深度解析
  • 【SpringBoot】公共字段自动填充
  • 【目标检测】用YOLOv8-Segment训练语义分割数据集(保姆级教学)
  • 零基础利用实战项目学会Pytorch
  • vue3 element el-table实现表格动态增加/删除/编辑表格行,带有校验规则
  • Python学习笔记第七十二天(Matplotlib imread)
  • 广西厂家直销建筑模板,工程用木工板,多层胶合板批发
  • 使用Intersection Observer API 检测元素是否出现在可视窗口
  • RK3568-pcie接口
  • LuatOS-SOC接口文档(air780E)--mcu - 封装mcu一些特殊操作
  • 如何在外网访问内网服务器数据库
  • 高通Quick Charge快速充电原理分析
  • Vue项目搭建及使用vue-cli创建项目、创建登录页面、与后台进行交互,以及安装和使用axios、qs和vue-axios
  • 在Linux中,可以使用以下命令来查看进程
  • tqdm 显示进度条模块
  • Echarts 实现 设备运行状态图(甘特图) 工业大数据展示
  • C++实现线程池
  • 软件工程第八周
  • 设计模式之中介模式
  • 2、基于pytorch lightning的fabric实现pytorch的多GPU训练和混合精度功能
  • Python学习笔记第六十九天(Matplotlib 直方图)
  • threejs(4)-纹理材质高级操作
  • 软件测试面试题
  • 面对6G时代 适合通信专业的 毕业设计题目
  • Unity Shader当用户靠近的时候会出现吃鸡一样的光墙