TensorFlow学习:使用官方模型进行图像分类并对模型进行微调
本文是对文章 TensorFlow学习:使用官方模型进行图像分类、使用自己的数据对模型进行微调
的补充说明。因为版本兼容的原因,原文有多处代码无法成功运行。这里把调整后的两处完整代码贴了出来,同时附上对应的模型文件(里面的文件或目录和Python文件放在相同目录下),以作对比。
运行环境:Mac 14.2、Python 3.12.2、tensorflow 2.17.0
使用成熟模型的完整代码:
# 导入tensorflow 和科学计算库
import tensorflow as tf
import numpy as np
# tensorflow-hub是一个TensorFlow库的扩展,它提供了一个简单的接口,用于重用已经训练好的机器学习模型的部分
import tensorflow_hub as hub
# 字体属性
from matplotlib.font_manager import FontProperties
# matplotlib是用于绘制图表和可视化数据的库
import matplotlib.pylab as plt
# 用于加载json文件
import json
import tf_keras
import ssl
import certifi
# 导入模型
# 不能直接加载模型文件,需要加载器目录
# 加载mobilenet_v2模型,这里要加载文件夹不要直接加载pb文件
# 模型如何加载要看文档,原来使用tf.keras.models.load_model加载一直失败
model = tf_keras.Sequential([
hub.KerasLayer('mobilenet-v2-classification')
])
# 假设输入为224x224 RGB图像
#input_shape = (224, 224, 3)
#input_layer = tf.keras.Input(shape=(input_shape))
#hub_layer = hub.KerasLayer("mobilenet_v2", trainable=True)
#x = hub_layer(input_layer)
#output_layer = tf.keras.layers.Dense(units, activation='activation_type')(x)
#
#model = tf.keras.Model(inputs=input_layer, outputs=output_layer)
#model = tf.keras.applications.mobilenet_v2.MobileNetV2()
print("模型信息:",model)
# 预处理输入数据
# 1、mobilenet需要的图片尺寸是 224 * 224
image = tf.keras.preprocessing.image.load_img('pics/dog.jpg',target_size=(224,224))
# 设置SSL上下文
#image =tf.keras.utils.get_file('bird.jpg','https://scpic.chinaz.net/files/default/imgs/2023-08-29/7dc085b6d3291303.jpg')
# 2、将图片转为数组,既是只有一张图片
image = tf.keras.preprocessing.image.img_to_array(image)
# 3、扩展数组维度,使其符合模型的输入
image = np.expand_dims(image, axis=0)
# 4、使用mobilenet_v2提供的预处理函数对图像处理,包括图像归一化、颜色通道顺序调整、像素值标准化等操作
image = tf.keras.applications.mobilenet_v2.preprocess_input(image)
# 预测
predictions = model.predict(image)
# 获取最高概率对应的类别索引
predicted_index = np.argmax(predictions)
# 概率值
confidence = np.max(predictions)
print("索引和概率值是:",predicted_index,confidence)
# 初始化一个空列表来存储文件的行
#labels_dict = []
# 加载映射文件
with open('mobilenet_v2/ImageNetLabels.txt','r') as f:
# labels_dict = json.load(f)
labels_dict = f.readlines()
# 类别的索引是字符串,这里要简单处理一下,这里-1是因为官方提供的多了一个0(背景),我找到的标签没有这个,因此要-1
class_name = labels_dict[predicted_index]
print(class_name)
# 可视化显示
font = FontProperties()
plt.figure() # 创建图像窗口
plt.xticks([])
plt.yticks([])
plt.grid(False) # 取消网格线
plt.imshow(image[0]) # 显示图片
plt.xlabel(class_name,fontproperties=font)
plt.show() # 显示图形窗口
对模型进行微调的完整代码
# 导入tensorflow 和科学计算库
import tensorflow as tf
import numpy as np
# tensorflow-hub是一个TensorFlow库的扩展,它提供了一个简单的接口,用于重用已经训练好的机器学习模型的部分
import tensorflow_hub as hub
# 字体属性
from matplotlib.font_manager import FontProperties
# matplotlib是用于绘制图表和可视化数据的库
import matplotlib.pylab as plt
import datetime
import tf_keras
# 导入模型
# 不能直接加载模型文件,需要加载器目录
model = tf_keras.Sequential([
hub.KerasLayer('mobilenet-v2-classification')
])
# 32张图片为一个批次,尺寸设置为224*224
batch_size = 32
img_height = 224
img_width = 224
# 加载图像数据集,并将其分割为训练集和验证集,验证集比例为20%
train_ds = tf.keras.utils.image_dataset_from_directory(
'flower_photos', # 目录
validation_split=0.2, # 验证集占20%
subset="training", # 将数据集划分为训练集
seed= 123, # 随机种子,用于数据集随机划分
image_size= (img_width,img_height) , # 调整图像大小
batch_size= batch_size # 每个批次中包含的图像数量
)
# 验证集
val_ds = tf.keras.utils.image_dataset_from_directory(
'flower_photos', # 目录
validation_split=0.2, # 验证集占20%
subset="validation", # 将数据集划分为验证集
seed= 123, # 随机种子,用于数据集随机划分
image_size= (img_width,img_height) , # 调整图像大小
batch_size= batch_size # 每个批次中包含的图像数量
)
# 花卉种类
class_names = np.array(train_ds.class_names)
print("花卉种类:",class_names)
# 归一化
normalization_layer = tf.keras.layers.Rescaling(1./255) # 创建了一个Rescaling层,将像素值缩放到0到1之间 。 1./255是 1/255保留小数,差点没看懂
train_ds = train_ds.map(lambda x,y:(normalization_layer(x),y))
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y))
# 使用缓冲预取,避免产生I/O阻塞
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
# 验证数据是否成功加载和处理
for image_batch, labels_batch in train_ds:
print(image_batch.shape)
print(labels_batch.shape)
break
# 对一批图片运行分类器,进行预测
result_batch = model.predict(train_ds)
# 加载映射文件,这里我将其下载到了本地
imagenet_labels = np.array(open('mobilenet-v2-feature-vector/ImageNetLabels.txt').read().splitlines())
# 在给定的张量中找到沿指定轴的最大值的索引
predict_class_names = imagenet_labels[tf.math.argmax(result_batch, axis=-1)]
print("预测类别:",predict_class_names)
# 绘制出预测与图片
# plt.figure(figsize=(10,9))
# plt.subplots_adjust(hspace=0.5)
# for n in range(30):
# plt.subplot(6,5,n+1)
# plt.imshow(image_batch[n])
# plt.title(predict_class_names[n])
# plt.axis('off')
# _ = plt.suptitle("ImageNet predictions")
# plt.show()
# 加载特征提取器
feature_extractor_layer = hub.KerasLayer(
'mobilenet-v2-feature-vector', # 预训练模型
input_shape=(224,224,3), # 指定图像输入的高度、宽度和通道数
trainable=False #训练过程中不更新特征提取器的权重
)
# 特征提取器为每个图像返回一个 1280 长的向量(在此示例中,图像批大小仍为 32)
feature_batch = feature_extractor_layer(image_batch)
print("特征批次形状:",feature_batch.shape)
# 附加分类头
new_model = tf_keras.Sequential([
feature_extractor_layer,
tf_keras.layers.Dense(len(class_names),activation='softmax') # 指定输出分类,这里的花是5类
])
# 训练模型
new_model.compile(
optimizer=tf_keras.optimizers.Adam(), # 使用Adam优化器作为优化算法
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), # 使用SparseCategoricalCrossentropy作为损失函数
metrics=['acc'] # 使用准确率作为评估指标
)
# 训练日志
log_dir = "logs/fit/" + datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
# 用于在训练过程中收集模型指标和摘要数据,并将其写入TensorBoard日志文件中
tensorboard_callback = tf_keras.callbacks.TensorBoard(
log_dir= log_dir,
histogram_freq=1
)
# 开始训练,暂时只训练10轮。history记录了训练过程中的各项指标,便于后续分析和可视化
history = new_model.fit(
train_ds, # 训练数据集
validation_data=val_ds, # 验证数据集,用于在训练过程中监控模型的性能
epochs=10, # 训练的总轮次
callbacks=tensorboard_callback # 回调函数,用于在训练过程中执行特定操作,比如记录日志
)
# 简单预测
# predicted_batch = new_model.predict(image_batch)
# predicted_id = tf.math.argmax(predicted_batch, axis=-1)
# predicted_label_batch = class_names[predicted_id]
# print("花卉种类:",predicted_label_batch)
# plt.figure(figsize=(10,9))
# plt.subplots_adjust(hspace=0.5)
# for n in range(30):
# plt.subplot(6,5,n+1)
# plt.imshow(image_batch[n])
# plt.title(predicted_label_batch[n].title())
# plt.axis('off')
# _ = plt.suptitle("Model predictions")
# plt.show()
# 导出训练好的模型
export_path = 'tmp/saved_models/flower_model'
new_model.save(export_path)