当前位置: 首页 > article >正文

卷积神经网络(CNN)模型 CIFAR-10 数据集 例子

使用 TensorFlow 构建一个简单的卷积神经网络(CNN)模型,完成对 CIFAR-10 数据集的图像分类任务。
使用自动编码器作为特征提取器,先通过自动编码器对图像数据进行降维,将图像从高维映射到低维特征空间,然后将提取的特征传入到 CNN 进行分类。
对比在不使用自动编码器特征提取的情况下,直接使用 CNN 进行分类的模型性能。

# 导入必要的库
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Conv2D, MaxPooling2D, Flatten

# 加载CIFAR - 10数据集
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# 预处理数据
x_train = x_train / 255.0
x_test = x_test / 255.0

# 构建自动编码器
input_img = Input(shape=(32, 32, 3))
# 编码器
encoded = Conv2D(16, (3, 3), activation='relu', padding='same')(input_img)
encoded = MaxPooling2D((2, 2), padding='same')(encoded)
# 解码器
decoded = Conv2D(16, (3, 3), activation='relu', padding='same')(encoded)
decoded = tf.keras.layers.UpSampling2D((2, 2))(decoded)
decoded = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(decoded)
autoencoder = Model(input_img, decoded)

# 编译自动编码器
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')

# 训练自动编码器
autoencoder.fit(x_train, x_train,
                epochs=10,
                batch_size=128,
                validation_data=(x_test, x_test))

# 获取编码器部分
encoder = Model(input_img, encoded)

# 使用编码器提取特征
x_train_encoded = encoder.predict(x_train)
x_test_encoded = encoder.predict(x_test)

# 构建CNN分类器
input_features = Input(shape=x_train_encoded.shape[1:])
flatten = Flatten()(input_features)
dense1 = Dense(128, activation='relu')(flatten)
output = Dense(10, activation='softmax')(dense1)
classifier = Model(input_features, output)

# 编译CNN分类器
classifier.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# 训练CNN分类器
classifier.fit(x_train_encoded, y_train,
               epochs=10,
               batch_size=128,
               validation_data=(x_test_encoded, y_test))

# 预测第1000个数据的类别(假设x_test是测试数据)
prediction = classifier.predict(x_test_encoded[999:1000])
predicted_class = tf.argmax(prediction, axis=1).numpy()[0]

Epoch 1/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 16ms/step - loss: 0.6075 - val_loss: 0.5594
Epoch 2/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 16ms/step - loss: 0.5575 - val_loss: 0.5567
Epoch 3/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5555 - val_loss: 0.5559
Epoch 4/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5547 - val_loss: 0.5551
Epoch 5/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5544 - val_loss: 0.5547
Epoch 6/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5538 - val_loss: 0.5534
Epoch 7/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 16ms/step - loss: 0.5520 - val_loss: 0.5527
Epoch 8/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5519 - val_loss: 0.5525
Epoch 9/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5517 - val_loss: 0.5523
Epoch 10/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5506 - val_loss: 0.5522
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1ms/step
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 821us/step
Epoch 1/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.2917 - loss: 1.9827 - val_accuracy: 0.4199 - val_loss: 1.6372
Epoch 2/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.4255 - loss: 1.6262 - val_accuracy: 0.4405 - val_loss: 1.5675
Epoch 3/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.4576 - loss: 1.5361 - val_accuracy: 0.4749 - val_loss: 1.5088
Epoch 4/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.4828 - loss: 1.4765 - val_accuracy: 0.5027 - val_loss: 1.4356
Epoch 5/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.4963 - loss: 1.4287 - val_accuracy: 0.5055 - val_loss: 1.4277
Epoch 6/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.5055 - loss: 1.3981 - val_accuracy: 0.5067 - val_loss: 1.4073
Epoch 7/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.5203 - loss: 1.3623 - val_accuracy: 0.5194 - val_loss: 1.3617
Epoch 8/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.5338 - loss: 1.3217 - val_accuracy: 0.5246 - val_loss: 1.3555
Epoch 9/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.5352 - loss: 1.3199 - val_accuracy: 0.5352 - val_loss: 1.3252
Epoch 10/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.5379 - loss: 1.3143 - val_accuracy: 0.5144 - val_loss: 1.3900
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 49ms/step

predicted_class

8

以下是使用 Python 语言结合 TensorFlow 库构建卷积神经网络(CNN)对 CIFAR-10 数据集进行图像分类,并获取第 1000 个数据预测类别的示例代码:

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import numpy as np

# 加载CIFAR-10数据集
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# 归一化像素值到0-1范围
train_images, test_images = train_images / 255.0, test_images / 255.0

# 构建简单的CNN模型
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10))  # 输出层,对应10个类别(CIFAR-10有10类)

# 编译模型
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# 训练模型
model.fit(train_images, train_labels, epochs=10, 
          validation_data=(test_images, test_labels))

# 对测试集进行预测
predictions = model.predict(test_images)
# 获取预测的类别(取概率最大的类别索引作为预测类别)
predicted_classes = np.argmax(predictions, axis=1)

# 获取第1000个数据的预测类别
print(predicted_classes[999])

Epoch 1/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 6ms/step - accuracy: 0.3424 - loss: 1.7848 - val_accuracy: 0.5325 - val_loss: 1.2917
Epoch 2/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.5668 - loss: 1.2166 - val_accuracy: 0.6098 - val_loss: 1.1129
Epoch 3/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 6ms/step - accuracy: 0.6391 - loss: 1.0275 - val_accuracy: 0.6399 - val_loss: 1.0059
Epoch 4/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 6ms/step - accuracy: 0.6676 - loss: 0.9352 - val_accuracy: 0.6690 - val_loss: 0.9239
Epoch 5/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.7065 - loss: 0.8364 - val_accuracy: 0.6935 - val_loss: 0.8983
Epoch 6/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.7272 - loss: 0.7781 - val_accuracy: 0.6914 - val_loss: 0.8845
Epoch 7/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.7507 - loss: 0.7195 - val_accuracy: 0.6843 - val_loss: 0.9197
Epoch 8/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.7609 - loss: 0.6772 - val_accuracy: 0.7024 - val_loss: 0.8741
Epoch 9/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.7773 - loss: 0.6337 - val_accuracy: 0.7055 - val_loss: 0.8704
Epoch 10/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.7888 - loss: 0.5993 - val_accuracy: 0.7074 - val_loss: 0.8714
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step
8


http://www.kler.cn/a/456894.html

相关文章:

  • 动手做计算机网络仿真实验入门学习
  • 本地小主机安装HomeAssistant开源智能家居平台打造个人AI管家
  • 使用JMeter对Linux生产服务器进行压力测试
  • C语言 练习
  • Kubernetes Gateway API-2-跨命名空间路由
  • 2024年个人总结
  • 学习,指针和FLASH
  • 02-18.python入门基础一基础算法
  • [江科大STM32] 第五集STM32工程模板——笔记
  • rk356x 下 qt 程序 hdmi不显示鼠标图标
  • 数值分析雨课堂章节测试
  • Java重要面试名词整理(十一):网络编程
  • 渗透测试常用术语总结
  • 深入了解JSON-LD:语义化网络数据的桥梁
  • v-if 和 v-show 的区别
  • anythingllm无法获取ollama模型
  • 在线学习平台-项目技术点-后台
  • 【计组】复习总结期末
  • 8.Java内置排序算法
  • mybatisplu设置自动填充
  • Chrome被360导航篡改了怎么改回来?
  • 【若依】RuoYi二开 -< 报错 >:com.ruoyi.common.exception.ServiceException: 获取用户信息异常
  • 寄存器控制LED灯亮
  • 前后端分离(前端删除数据库数据)
  • Linux top指令
  • Hadoop的生态系统所包含的组件