当前位置: 首页 > article >正文

CNN应用Keras Tuner寻找最佳Hidden Layers层数和神经元数量

介绍: 

Keras Tuner是一种用于优化Keras模型超参数的开源Python库。它允许您通过自动化搜索算法来寻找最佳的超参数组合,以提高模型的性能。Keras Tuner提供了一系列内置的超参数搜索算法,如随机搜索、网格搜索、贝叶斯优化等。它还支持自定义搜索空间和搜索算法。通过使用Keras Tuner,您可以更轻松地优化模型的性能,节省调参的时间和精力。

数据: 

from tensorflow.keras.datasets import fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

'''
Label   Description
0   T-shirt/top
1   Trouser
2   Pullover
3   Dress
4   Coat
5   Sandal
6   Shirt
7   Sneaker
8   Bag
9   Ankle boot
'''

import matplotlib.pyplot as plt
%matplotlib inline

print(y_test[0])
plt.imshow(x_test[0], cmap="gray") 
#each having 1 channel (grayscale, it would have been 3 in the case of color, 1 each for Red, Green and Blue)

建模:

x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1) 

from tensorflow import keras
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Activation
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

model = keras.models.Sequential()

model.add(Conv2D(32, (3, 3), input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())  # this converts our 3D feature maps to 1D feature vectors

model.add(Dense(10))
model.add(Activation("softmax"))

model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

model.fit(x_train, y_train, batch_size=64, epochs=1, validation_data = (x_test, y_test))

 

 Keras Tuner:

from kerastuner.tuners import RandomSearch
from kerastuner.engine.hyperparameters import HyperParameters

def build_model(hp):  # random search passes this hyperparameter() object 
    model = keras.models.Sequential()

    model.add(Conv2D(hp.Int('input_units',
                                min_value=32,
                                max_value=256,
                                step=32), (3, 3), input_shape=x_train.shape[1:]))

    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    for i in range(hp.Int('n_layers', 1, 4)):  # adding variation of layers.
        model.add(Conv2D(hp.Int(f'conv_{i}_units',
                                min_value=32,
                                max_value=256,
                                step=32), (3, 3)))
        model.add(Activation('relu'))

    model.add(Flatten()) 
    model.add(Dense(10))
    model.add(Activation("softmax"))

    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])

    return model

tuner = RandomSearch(
    build_model,
    objective='val_accuracy',
    max_trials=1,  # how many model variations to test?
    executions_per_trial=1,  # how many trials per variation? (same model could perform differently)
    directory='Lesson56',
    project_name='Optimise')


tuner.search(x=x_train,
             y=y_train,
             verbose=1, # just slapping this here bc jupyter notebook. The console out was getting messy.
             epochs=1,
             batch_size=64,
             #callbacks=[tensorboard],  # if you have callbacks like tensorboard, they go here.
             validation_data=(x_test, y_test))

tuner.results_summary()

 


http://www.kler.cn/news/234025.html

相关文章:

  • vue3跨组件(多组件)通信:事件总线【Event Bus】
  • 修改GI文件的权限
  • 双活工作关于nacos注册中心的数据迁移
  • C#系列-C#访问MongoDB+redis+kafka(7)
  • Avalonia学习(二十三)-大屏
  • 方格定位2_题解
  • Qt安装配置教程windows版(包括:Qt5.8.0版本,Qt5.12,Qt5.14版本下载安装教程)(亲测可行)
  • STM32 FSMC (Flexible static memory controller) 灵活静态内存控制器介绍
  • Android java基础_类的继承
  • python如何用glob模块匹配路径
  • Lua Global环境
  • 时间序列预测——BiGRU模型
  • 应急响应-挖矿木马-常规处置方法
  • notepad++成功安装后默认显示英文怎么设置中文界面?
  • 突破编程_C++_面试(基础知识(10))
  • 学习笔记——ENM模拟
  • 微服务学习 | Spring Cloud 中使用 Sentinel 实现服务限流
  • 零基础学编程从入门到精通,系统化的编程视频教程上线,中文编程开发语言工具构件之缩放控制面板构件用法
  • Centos 7系统安装proftpd-1.3.8过程
  • 图像的旋转不变特性及应用
  • React18原理: Fiber架构下的单线程CPU调度策略
  • 代码随想录-背包问题
  • vue3 之 商城项目—详情页
  • 双非本科准备秋招(21.1)—— 力扣二叉搜索树
  • C++ //练习 5.4 说明下列例子的含义,如果存在问题,试着修改它。
  • 随机MM引流源码PHP开源版
  • RabbitMQ-1.介绍与安装
  • django中实现观察者模式
  • Python面试题7-12
  • Linux无交互自动安装miniconda3