当前位置: 首页 > article >正文

AI学习指南深度学习篇-Adam的Python实践

AI学习指南深度学习篇-Adam的Python实践

在深度学习领域,优化算法是影响模型性能的关键因素之一。Adam(Adaptive Moment Estimation)是一种广泛使用的优化算法,因其在多种问题上均表现优异而被广泛使用。本文将深入探讨Adam优化器,并提供详细的代码示例,展示如何在Python的深度学习库(如TensorFlow和PyTorch)中实现Adam,进行模型训练以及调参过程。

引言

优化算法的选择会影响深度学习模型的收敛速度和最终性能。Adam算法不仅结合了动量(Momentum)的优点,还引入了自适应学习率,这使得其在许多任务中表现良好。本文将通过实际代码示例介绍Adam的实现和调参过程,让读者能够在自己的项目中有效应用这一算法。

Adam优化器概述

2.1 公式推导

Adam优化器的核心思想是计算梯度的动量以及梯度的平方动量,并利用这两个动量来调整学习率。Adam的更新公式如下:

  1. 初始化参数

    • ( m t = 0 ) ( m_t = 0 ) (mt=0)(一阶矩估计)
    • ( v t = 0 ) ( v_t = 0 ) (vt=0)(二阶矩估计)
    • ( t = 0 ) ( t = 0 ) (t=0)(时间步长)
    • ( β 1 , β 2 ) ( \beta_1, \beta_2 ) (β1,β2)(通常取值为0.9,0.999)
    • ( ϵ ) ( \epsilon ) (ϵ)(通常取小值以避免除零错误)
  2. 参数更新
    [ t = t + 1 ] [ t = t + 1 ] [t=t+1]
    [ m t = β 1 ⋅ m t − 1 + ( 1 − β 1 ) ⋅ g t ] [ m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t ] [mt=β1mt1+(1β1)gt]
    [ v t = β 2 ⋅ v t − 1 + ( 1 − β 2 ) ⋅ g t 2 ] [ v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2 ] [vt=β2vt1+(1β2)gt2]
    [ m ^ t = m t 1 − β 1 t ] [ \hat{m}_t = \frac{m_t}{1 - \beta_1^t} ] [m^t=1β1tmt]
    [ v ^ t = v t 1 − β 2 t ] [ \hat{v}_t = \frac{v_t}{1 - \beta_2^t} ] [v^t=1β2tvt]
    [ θ t = θ t − 1 − α v ^ t + ϵ ⋅ m ^ t ] [ \theta_{t} = \theta_{t-1} - \frac{\alpha}{\hat{v}_t + \epsilon} \cdot \hat{m}_t ] [θt=θt1v^t+ϵαm^t]

2.2 参数说明

  • 学习率 ( ( α ) ) ((\alpha)) ((α)):控制每次更新的步幅,通常初始值设为0.001。
  • ( β 1 ) (\beta_1) (β1) ( β 2 ) (\beta_2) (β2):分别控制一阶矩和二阶矩的衰减率。
  • ( ϵ ) (\epsilon) (ϵ):通常设为 ( 1 0 − 8 ) (10^{-8}) (108),避免在计算时出现除零错误。

在TensorFlow中使用Adam

3.1 环境准备

确保你的计算环境中安装了TensorFlow和其他必要的库:

pip install tensorflow numpy matplotlib

3.2 数据加载

我们将使用Keras提供的MNIST手写数字数据集作为示例:

import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt

# 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 数据预处理
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

3.3 构建模型

我们将定义一个简单的神经网络模型:

def create_model():
    model = models.Sequential()
    model.add(layers.Flatten(input_shape=(28, 28)))
    model.add(layers.Dense(128, activation="relu"))
    model.add(layers.Dropout(0.2))
    model.add(layers.Dense(10, activation="softmax"))
    return model

3.4 训练模型

使用Adam优化器训练模型:

model = create_model()

# 编译模型
model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["accuracy"])

# 训练模型
history = model.fit(x_train, y_train, epochs=10, batch_size=32, validation_split=0.2)

3.5 调整超参数

可以通过以下方式调整超参数,比如修改学习率或尝试不同的批大小:

from tensorflow.keras.optimizers import Adam

# 创建自定义Adam优化器
adam = Adam(learning_rate=0.001)

# 重新编译模型
model.compile(optimizer=adam, 
              loss="categorical_crossentropy", 
              metrics=["accuracy"])

# 重新训练模型
history = model.fit(x_train, y_train, epochs=10, batch_size=64, validation_split=0.2)

在PyTorch中使用Adam

4.1 环境准备

确保你的计算环境中安装了PyTorch和其他必要的库:

pip install torch torchvision numpy matplotlib

4.2 数据加载

与TensorFlow类似,我们将使用同样的数据集:

import torch
from torchvision import datasets, transforms
from torch import nn, optim

# 数据预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# 加载MNIST数据集
trainset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)

testset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)

4.3 构建模型

PyTorch模型构建如下:

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.dropout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.shape[0], -1)  # 展平操作
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

model = SimpleNN()

4.4 训练模型

使用Adam优化器训练模型的示例如下:

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
epochs = 10
for epoch in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        optimizer.zero_grad()  # 清空梯度
        output = model(images)  # 前向传播
        loss = criterion(output, labels)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数
        running_loss += loss.item()
    print(f"Epoch {epoch + 1}/{epochs} - Loss: {running_loss/len(trainloader)}")

4.5 调整超参数

在PyTorch中,你也可以像在TensorFlow中那样调整超参数,下面是修改学习率的例子:

# 创建自定义Adam优化器
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# 重新训练模型
for epoch in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch + 1}/{epochs} - Loss: {running_loss/len(trainloader)}")

结论

Adam优化器因其良好的自适应性和快速的收敛能力,成为深度学习中最流行的优化算法之一。在TensorFlow和PyTorch等深度学习框架中,Adam均被用户广泛应用。本文详细介绍了在这两种框架中使用Adam优化器进行模型训练的完整流程,并展示了如何在训练过程中灵活调整超参数。希望这篇文章能帮助你更好地理解和应用Adam优化器。尽管TensorFlow和PyTorch有其独特之处,但选用合适的优化器对于模型的最终表现仍然至关重要。在实际应用中,建议尝试多种优化算法并进行超参数调整,以获得最佳的训练效果。

如果想了解更深入的Adam算法工作原理或其他优化算法的使用,请关注后续更新,继续学习更多的深度学习内容。


http://www.kler.cn/a/312234.html

相关文章:

  • 【Linux篇】面试——用户和组、文件类型、权限、进程
  • 羊城杯2020Easyphp
  • MySQL如何利用索引优化ORDER BY排序语句
  • 吴恩达机器学习笔记(3)
  • Redis安装(Windows环境)
  • 【大数据学习 | HBASE】hbase的读数据流程与hbase读取数据
  • 如何配置和使用自己的私有 Docker Registry
  • python的6种常用数据结构
  • 3.大语言模型LLM的公开资源(API、语料库、算法库)
  • Python中的树与图:构建复杂数据结构的艺术
  • 图论三元环(并查集的高级应用)
  • 天润融通创新功能,将无效会话转化为企业新商机
  • 青柠视频云——视频丢包(卡顿、花屏、绿屏)排查
  • Python 集合的魔法:解锁高效数据处理的秘密
  • 工厂模式(一):简单工厂模式
  • Web后端服务平台解析漏洞与修复、文件包含漏洞详解
  • 【Git原理与使用】多人协作与开发模型(2)
  • 杀死端口占用的进程
  • 常用函数式接口的使用
  • 3D GS 测试自己的数据
  • react 甘特图之旅
  • C语言 | Leetcode C语言题解之第405题数字转换为十六进制数
  • SpringBoot 数据库表结构文档生成
  • 深入Redis:核心的缓存
  • 【计算机网络 - 基础问题】每日 3 题(十四)
  • 百易云资产系统 house.save.php SQL注入