当前位置: 首页 > article >正文

Optuna深度学习自动调参工具使用简明教程

Optuna 是一个灵活的超参数优化框架,广泛用于机器学习和深度学习模型的超参数调优。
自动调参工具可是太爽了,先给大家看一个训练例子的dashboard图
在这里插入图片描述
请添加图片描述

环境配置简易直接搞一个docker

$ docker run --rm -v $(pwd):/prj -w /prj optuna/optuna:py3.7-dev /bin/bash

我们分析以下官方提供的示例torch例子,使用 Optuna 进行超参数优化的示例,旨在优化使用 PyTorch 训练的多层感知器(MLP)在 FashionMNIST 数据集上的验证准确率。下面是代码逐部分的详细讲解:

1. 导入必要的库

import os
import optuna
from optuna.trial import TrialState
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torchvision import datasets
from torchvision import transforms
  • os:用于与操作系统交互,获取当前工作目录。
  • optuna:用于超参数优化。
  • torchtorch.nntorch.optim:PyTorch 的主要库,用于构建和训练模型。
  • torchvision:提供用于计算机视觉任务的数据集和转换工具。

2. 定义常量和设备

DEVICE = torch.device("cpu")
BATCHSIZE = 128
CLASSES = 10
DIR = os.getcwd()
EPOCHS = 10
N_TRAIN_EXAMPLES = BATCHSIZE * 30
N_VALID_EXAMPLES = BATCHSIZE * 10
  • DEVICE:设置为 CPU,可以根据需要修改为 GPU。
  • BATCHSIZE:每个训练批次的样本数。
  • CLASSES:分类的数量(FashionMNIST 有 10 个类别)。
  • DIR:当前工作目录,用于保存数据集。
  • EPOCHS:训练的轮数。
  • N_TRAIN_EXAMPLESN_VALID_EXAMPLES:训练和验证中使用的样本数限制。

3. 定义模型结构

def define_model(trial):
    n_layers = trial.suggest_int("n_layers", 1, 3)
    layers = []

    in_features = 28 * 28
    for i in range(n_layers):
        out_features = trial.suggest_int("n_units_l{}".format(i), 4, 128)
        layers.append(nn.Linear(in_features, out_features))
        layers.append(nn.ReLU())
        p = trial.suggest_float("dropout_l{}".format(i), 0.2, 0.5)
        layers.append(nn.Dropout(p))

        in_features = out_features
    layers.append(nn.Linear(in_features, CLASSES))
    layers.append(nn.LogSoftmax(dim=1))

    return nn.Sequential(*layers)
  • define_model:根据试验中的建议参数创建模型。
  • n_layers:建议的层数(1 到 3)。
  • out_features:每层的神经元数量(4 到 128)。
  • Dropout:防止过拟合,通过随机丢弃一定比例的神经元。

4. 获取数据集

def get_mnist():
    train_loader = torch.utils.data.DataLoader(
        datasets.FashionMNIST(DIR, train=True, download=True, transform=transforms.ToTensor()),
        batch_size=BATCHSIZE,
        shuffle=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        datasets.FashionMNIST(DIR, train=False, transform=transforms.ToTensor()),
        batch_size=BATCHSIZE,
        shuffle=True,
    )

    return train_loader, valid_loader
  • get_mnist:加载 FashionMNIST 数据集,并返回训练和验证的 DataLoader。
  • 使用 transforms.ToTensor() 将图像转换为张量。

5. 定义目标函数

def objective(trial):
    model = define_model(trial).to(DEVICE)

    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)

    train_loader, valid_loader = get_mnist()

    for epoch in range(EPOCHS):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            if batch_idx * BATCHSIZE >= N_TRAIN_EXAMPLES:
                break

            data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)

            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()

        model.eval()
        correct = 0
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(valid_loader):
                if batch_idx * BATCHSIZE >= N_VALID_EXAMPLES:
                    break
                data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)
                output = model(data)
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        accuracy = correct / min(len(valid_loader.dataset), N_VALID_EXAMPLES)

        trial.report(accuracy, epoch)

        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return accuracy
  • objective:优化的目标函数,返回模型的验证准确率。
  • 优化器:根据试验建议的优化器类型(Adam、RMSprop 或 SGD)和学习率。
  • 训练和验证:训练模型并在验证集上计算准确率。
  • 剪枝:使用 trial.should_prune() 检查是否提前停止试验。

6. 主程序

if __name__ == "__main__":
    study = optuna.create_study(direction="maximize",storage="sqlite:///db.sqlite3")  
    study.optimize(objective, n_trials=100, timeout=600)

    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))

    print("Best trial:")
    trial = study.best_trial

    print("  Value: ", trial.value)

    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))![请添加图片描述](https://i-blog.csdnimg.cn/direct/85378d9b615948bebd3d90ffc8096422.png)

  • 创建一个优化研究(study),目标是最大化准确率。
  • 运行 100 次试验,每次最多花费 600 秒。
  • 统计并输出完成和被剪枝的试验数量。
  • 输出最佳试验的结果及其超参数。
  • 其中我们将训练的过程保存在了db.sqlite3这个文件里(这个自己加一下, 官方例子米有),官方提供了可解析的工具:optuna-dashboard

我们运行例子程序就启动训练, 输出的log如下:

请添加图片描述

可以看到log里有pruned的尝试,也有finished的尝试,optuna会自动终止认为不收敛的超参数配置的训练

7. dashboard 分析训练过程

先安装一下工具:

$ pip install optuna-dashboard

导入你的训练log文件

$ optuna-dashboard sqlite:///db.sqlite3
Listening on http://localhost:8080/
Hit Ctrl-C to quit.

打开你的浏览器就可以查看训练的log了,如博客最上方的图,可以看到每一次迭代的超参数,以及精度的变化

总结

这个示例展示了如何使用 Optuna 进行超参数优化,通过对多层感知器的层数、每层的单元数、Dropout 比例和优化器等超参数进行调优,来提高模型在 FashionMNIST 数据集上的验证准确率。通过使用剪枝功能,可以在训练过程中提前终止表现不佳的试验,从而节省计算资源。

参考

https://github.com/optuna/optuna-examples/tree/main
https://github.com/optuna/optuna-dashboard?tab=readme-ov-file


http://www.kler.cn/a/372323.html

相关文章:

  • 【Python】selenium结合js模拟鼠标点击、拦截弹窗、鼠标悬停方法汇总(使用 execute_script 执行点击的方法)
  • 单片机常用外设开发流程(1)(IMX6ULL为例)
  • flux中的缓存
  • 从0入门自主空中机器人-4-【PX4与Gazebo入门】
  • LeetCode算法题——有序数组的平方
  • 《Vue进阶教程》第三十一课:ref的初步实现
  • Java 文件路径一口气讲完!(* ̄3 ̄)╭
  • 牛客网刷题(3)(Java的几种常用包)
  • 实操|如何优雅的实现RAG与GraphRAG应用中的知识文档增量更新?
  • Webserver(1.8)操作函数
  • CSS常见适配布局方式
  • 逆变器竞品分析--倍思500W方案【2024/10/30】
  • Android 快捷方式
  • 海外共享奶牛牧场投资源码-理财金融源码-基金源码-共享经济源码
  • 《掌握 Java:从基础到高级概念的综合指南》(3/15)
  • 多GPU训练大语言模型,DDP, ZeRO 和 FSDP
  • 【再谈设计模式】单例模式~唯一性的守护者
  • Dockerfile制作Oracle19c镜像
  • xpath爬虫
  • 多线程显示 CSV 2 PNG 倒计时循环播放
  • 低功耗模组学习指南:从入门到精通通过MQTT连接实现远程控制
  • 如何在不同设备上轻松下载Facebook应用:全面指南
  • AI助力医疗数据自动化:诊断报告识别与管理
  • TCP全连接队列与 tcpdump 抓包
  • vue点击菜单,出现2个相同tab,啥原因
  • 代码备份管理 —— Git实用操作