当前位置: 首页 > article >正文

T6识别好莱坞明星

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

导入基础的包

from tensorflow       import keras
from tensorflow.keras import layers,models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow        as tf
import numpy             as np

读取本地的好莱坞明星文件构建数据集。

data_dir = "./48-data/"

data_dir = pathlib.Path(data_dir)

打印文件的数量,一共1800张图片。

image_count = len(list(data_dir.glob('*/*.jpg')))

print("图片总数为:",image_count)

构建训练集

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.1,
    subset="training",
    label_mode = "categorical",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

构建验证集

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.1,
    subset="validation",
    label_mode = "categorical",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

构建网络模型

model = models.Sequential([
    layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
    
    layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)), # 卷积层1,卷积核3*3  
    layers.AveragePooling2D((2, 2)),               # 池化层1,2*2采样
    layers.Conv2D(32, (3, 3), activation='relu'),  # 卷积层2,卷积核3*3
    layers.AveragePooling2D((2, 2)),               # 池化层2,2*2采样
    layers.Dropout(0.5),  
    layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3
    layers.AveragePooling2D((2, 2)),     
    layers.Dropout(0.5),  
    layers.Conv2D(128, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3
    layers.Dropout(0.5), 
    
    layers.Flatten(),                       # Flatten层,连接卷积层与全连接层
    layers.Dense(128, activation='relu'),   # 全连接层,特征进一步提取
    layers.Dense(len(class_names))               # 输出层,输出预期结果
])

model.summary()  # 打印网络结构

设置学习率,并且编译网络模型

initial_learning_rate = 1e-4

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, 
        decay_steps=60,      # 敲黑板!!!这里是指 steps,不是指epochs
        decay_rate=0.96,     # lr经过一次衰减就会变成 decay_rate*lr
        staircase=True)

# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

model.compile(optimizer=optimizer,
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

开始训练

轮次 100轮,保存最佳的模型参数。

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

epochs = 100

# 保存最佳模型参数
checkpointer = ModelCheckpoint('best_model.h5',
                                monitor='val_accuracy',
                                verbose=1,
                                save_best_only=True,
                                save_weights_only=True)

# 设置早停
earlystopper = EarlyStopping(monitor='val_accuracy', 
                             min_delta=0.001,
                             patience=20, 
                             verbose=1)

开始训练


history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=epochs,
                    callbacks=[checkpointer, earlystopper])

画图训练集和测试集的 准确率和丢失率。

from PIL import Image
import numpy as np

img = Image.open("./48-data/Jennifer Lawrence/003_963a3627.jpg") 
image = tf.image.resize(img, [img_height, img_width])

img_array = tf.expand_dims(image, 0) 

predictions = model.predict(img_array) 
print("预测结果为:",class_names[np.argmax(predictions)])


http://www.kler.cn/a/396356.html

相关文章:

  • STM32+AI语音识别智能家居系统
  • ES6更新的内容中什么是proxy
  • Linux之vim全选,全部复制,全部删除
  • 【Java基础知识系列】之Java类的初始化顺序
  • Dubbo 3.2 源码导读
  • 31-Shard Allocation Awareness(机架感知)
  • maven手动上传jar到私服仓库:mvn deploy:deploy-file命令
  • linux rsync 同步拉取上传文件
  • 【SpringBoot】使用过滤器进行XSS防御
  • 在uniapp中使用canvas封装组件遇到的坑,数据被后面设备覆盖,导致数据和前面的设备一样
  • 编译原理(手绘)
  • 2024年【A特种设备相关管理(A4电梯)】新版试题及A特种设备相关管理(A4电梯)找解析
  • 【AlphaFold3】开源本地的安装及使用
  • [Mysql] Mysql的多表查询----多表关系(下)
  • 精华帖分享|浅谈金融时间序列分析与股价随机游走
  • Maven配置元素详解
  • MATLAB中的绘图技巧
  • 高并发下如何保障系统的正确性?性能与一致性博弈的技术探索
  • ⾃动化运维利器 Ansible-Jinja2
  • 【MySQL】索引原理及操作
  • 如何用Python爬虫精准获取商品历史价格信息及API数据
  • sql server into #t2 到临时表的几种用法
  • 8 软件项目管理
  • JavaScript 自动化软件:AutoX.js
  • 入门车载以太网(4) -- 传输层(TCP\UDP)
  • django入门【05】模型介绍(二)——字段选项