当前位置: 首页 > article >正文

Tensorflow2.0:CNN、ResNet实现MNIST分类识别

以下仅是个人的学习笔记,内容可能存在错误。

CNN: 

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Load the MNIST train/test splits bundled with Keras.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Preprocess: reshape every image to 28x28 with one trailing channel axis
# and divide the pixel values by 255.
x_train, x_test = (
    images.reshape(-1, 28, 28, 1) / 255.0 for images in (x_train, x_test)
)

# Build the model layer by layer: one conv layer, one max-pool, then a
# 10-way softmax classifier on the flattened feature map.
model = keras.Sequential()
model.add(
    layers.Conv2D(
        filters=32,
        kernel_size=(3, 3),
        activation='relu',
        input_shape=(28, 28, 1),
    )
)
model.add(layers.MaxPooling2D(pool_size=(2, 2)))
model.add(layers.Flatten())
model.add(layers.Dense(10, activation='softmax'))

# Compile; the sparse loss is used because the labels are passed as loaded,
# without one-hot encoding.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Train for 5 epochs, using the test split as validation data.
model.fit(
    x_train,
    y_train,
    epochs=5,
    validation_data=(x_test, y_test),
)

# Evaluate on the test split and report the accuracy.
test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test accuracy:', test_acc)

ResNet18(注意:下面的示例加载的是 CIFAR-10 数据集,而不是标题中的 MNIST):

import tensorflow as tf
from keras import layers, models, datasets
import os

# GPU setup: select GPU index 0 for this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Request GPU memory on demand instead of all at once.
        tf.config.experimental.set_memory_growth(gpus[0], True)
    except RuntimeError as err:
        # Report the problem and carry on rather than aborting the script.
        print(err)

# Load the CIFAR-10 train/test splits.
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Preprocess: divide the pixel values of both splits by 255.
train_images = train_images / 255.0
test_images = test_images / 255.0


# 搭建残差模块
def resnet_block(inputs, num_filters=16, kernel_size=3, strides=1, activation='relu'):
    """Apply Conv2D -> BatchNormalization -> optional activation to ``inputs``.

    Args:
        inputs: Tensor fed to the convolution.
        num_filters: Number of convolution filters.
        kernel_size: Convolution kernel size.
        strides: Convolution stride.
        activation: Activation passed to ``layers.Activation``; any falsy
            value (e.g. ``None``) skips the activation entirely.

    Returns:
        The output tensor of the last layer applied.
    """
    conv_out = layers.Conv2D(
        num_filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',  # 'same' padding on the convolution
    )(inputs)
    normed = layers.BatchNormalization()(conv_out)
    if not activation:
        return normed
    return layers.Activation(activation)(normed)


# 定义resnet
def resnet18():
    """Build a residual CNN for 32x32x3 inputs with a 10-way softmax output.

    Layout:
      * input batch normalization followed by a 64-filter conv/BN/ReLU stem;
      * three stages with 64, 128 and 256 filters, each holding two residual
        units;
      * a stride-2 downsampling unit before the second and third stage;
      * average pooling, flatten and a 10-unit softmax dense layer.

    Each residual unit adds the unit's *input* to its conv/BN output and then
    applies ReLU. The earlier version overwrote the tensor with the conv
    output first and then computed ``t + relu(t)``, so the unit input never
    reached the ``Add`` and there was no identity shortcut at all.

    NOTE: the number of layers stacked here does not match the canonical
    18-layer ResNet; the name is kept so existing callers keep working.

    Returns:
        An uncompiled ``models.Model`` mapping the input tensor to the
        softmax output.
    """
    inputs = layers.Input(shape=(32, 32, 3))
    num_filters = 64
    t = layers.BatchNormalization()(inputs)
    t = resnet_block(t, num_filters=num_filters)

    for stage in range(3):
        if stage > 0:
            # Transition between stages: double the filter count and use a
            # stride-2 convolution, then add a second conv branch on top of
            # the downsampled tensor (unchanged from the original layout).
            num_filters *= 2
            t = resnet_block(t, num_filters=num_filters, strides=2, activation=None)
            t = layers.Add()([t, resnet_block(t, num_filters=num_filters)])

        for _ in range(2):
            # Keep the unit input so it can bypass the convolution.
            shortcut = t
            t = resnet_block(t, num_filters=num_filters, activation=None)
            # Shapes match: stride 1, 'same' padding and an unchanged filter
            # count inside a stage.
            t = layers.Add()([shortcut, t])
            t = layers.Activation('relu')(t)

    t = layers.AveragePooling2D()(t)
    outputs = layers.Dense(10, activation='softmax')(layers.Flatten()(t))
    model = models.Model(inputs, outputs)
    return model


# Build the network and compile it; the sparse loss is used because the
# labels are passed as loaded, without one-hot encoding.
model = resnet18()
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Train for 10 epochs inside an explicit GPU:0 device scope, using the test
# split as validation data. To train on the CPU instead, call model.fit
# with the same arguments outside the tf.device block.
with tf.device('GPU:0'):
    history = model.fit(
        train_images,
        train_labels,
        epochs=10,
        validation_data=(test_images, test_labels),
    )

 


http://www.kler.cn/news/134415.html

相关文章:

  • 求二叉树的高度(可运行)
  • buildadmin+tp8表格操作(3)----表头上方按钮绑定事件处理,实现功能(选中或取消指定行)
  • 互联网摸鱼日报(2023-11-20)
  • wpf devexpress post 更改数据库
  • kafka分布式安装部署
  • 【微信小程序篇】- 组件
  • 算法设计与分析复习--贪心(一)
  • 特效!视频里的特效在哪制作——Adobe After Effects
  • java智慧校园信息管理系统源码带微信小程序
  • 【wp】2023第七届HECTF信息安全挑战赛 Web
  • 什么是Selenium?如何使用Selenium进行自动化测试?
  • 初刷leetcode题目(4)——数据结构与算法
  • C# Array和ArrayList有什么区别
  • WPF拖拽相关的类
  • 详解Java设计模式之职责链模式
  • S7-1200PLC 作为MODBUSTCP服务器通信(多客户端访问)
  • Web安全研究(五)
  • python中Thread实现多线程任务
  • iTerm2+oh-my-zsh搭个Mac电脑上好用好看终端
  • Zotero在word中插入参考文献
  • 队列和微服务的异步通信
  • Python选择排序和冒泡排序算法
  • linux基础:4:gdb的使用
  • 保姆级 | Nginx编译安装
  • golang学习笔记——条件表达式
  • 【Dubbo】Dubbo负载均衡实现解析
  • nodejs微信小程序-实验室上机管理系统的设计与实现-安卓-python-PHP-计算机毕业设计
  • 2023数维杯国际赛数学建模竞赛选题建议及B题思路讲解
  • Linux本地docker一键部署traefik+内网穿透工具实现远程访问Web UI管理界面
  • OpenAI 地震!首席执行官被解雇,背后的原因是?