当前位置: 首页 > article >正文

【RL Base】强化学习核心算法:深度Q网络(DQN)算法

        📢本篇文章是博主强化学习(RL)领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对相关等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在👉强化学习专栏:

       【强化学习】(50)---《强化学习核心算法:深度Q网络(DQN)算法》

强化学习核心算法:深度Q网络(DQN)算法

目录

1.深度Q网络(Deep Q-Network, DQN)算法详解

2.DQN基本原理

1. Q值函数

2. Bellman方程

3. 深度Q网络

3.DQN算法关键步骤

[Python] DQN算法实现

DQN算法在gym环境中实现

1.库导入

 2.定义Q网络

 3.定义智能体

 4.训练代码

 5.主函数 

[Notice]  说明

4.重要改进

5.DQN的强化学习背景应用


1.深度Q网络(Deep Q-Network, DQN)算法详解

        深度Q网络(DQN)是深度强化学习的核心算法之一,由Google DeepMind在2015年的论文《Playing Atari with Deep Reinforcement Learning》中提出。DQN通过结合深度学习和强化学习,利用神经网络近似Q值函数,在高维、连续状态空间的环境中表现出了强大的能力。


2.DQN基本原理

        DQN的目标是通过学习动作-价值函数Q(s, a),来找到最优策略,使得智能体在每个状态 s 下执行动作a能获得的未来累积奖励最大化。

1. Q值函数

Q值函数表示在状态( s )下执行动作 ( a )后能够获得的期望回报:

Q(s, a) = \mathbb{E}\left[ \sum_{t=0}^\infty \gamma^t r_t \mid s_0 = s, a_0 = a \right]

  • ( r_t ): 第 ( t ) 步的奖励。
  • ( \gamma ): 折扣因子,控制未来奖励的权重。
2. Bellman方程

Q值函数满足Bellman最优方程:

Q^(s, a) = r + \gamma \max_{a'} Q^(s', a')

  • ( s' ): 当前状态 ( s )执行动作( a ) 后转移到的下一个状态。
  • ( a' ): 下一步的可能动作。
3. 深度Q网络

        DQN使用神经网络来近似Q值函数( Q(s, a; \theta) ),其中( \theta )是网络参数。网络输入是状态 ( s ),输出是对应每个动作的Q值。


3.DQN算法关键步骤

3.1经验回放(Experience Replay)

        通过存储智能体的交互经验 ( (s, a, r, s') ) 在缓冲区中,并从中随机采样训练神经网络,打破时间相关性,提高数据样本效率。

3.2目标网络(Target Network)

        使用一个目标网络( Q(s, a; \theta^-))来计算目标值,而不是直接使用当前网络。这减少了训练不稳定性。

        每隔一定步数,将当前网络的参数( \theta )同步到目标网络( \theta^- )

3.3损失函数

        使用均方误差(MSE)作为损失函数:

L(\theta) = \mathbb{E}_{(s, a, r, s') \sim D}\left[\left(y - Q(s, a; \theta)\right)^2\right]

        其中目标值 ( y )为:y = r + \gamma \max_{a'} Q(s', a'; \theta^-)

3.4探索与利用(Exploration vs Exploitation)

        使用\epsilon-贪心策略,在动作选择上加入随机性


[Python] DQN算法实现

DQN算法伪代码

"""《DQN算法伪代码》
    时间:2024.11
    作者:不去幼儿园
"""
# 随机初始化 Q 网络的参数 θ
# θ 表示 Q 网络的权重,用于近似 Q 值函数
初始化 Q 网络参数 θ 随机

# 将目标 Q 网络的参数 θ^- 初始化为 Q 网络参数 θ 的值
# θ^- 是一个独立的目标网络,用于稳定 Q 值更新
初始化目标 Q 网络参数 θ^- = θ

# 初始化经验回放缓冲区 D
# D 是一个数据结构(例如 deque),存储智能体的交互经验 (状态, 动作, 奖励, 下一个状态)
初始化经验回放缓冲区 D

# 循环进行 M 个训练轮次(即 M 个 episode)
for episode = 1, M do
    # 初始化环境并获得初始状态 s
    # 这个状态将作为本轮 episode 的起点
    初始化状态 s

    # 循环处理每个时间步,T 是每轮 episode 的最大时间步数
    for t = 1, T do
        # 根据 ε-贪心策略选择动作
        # 以 ε 的概率随机选择动作(探索)
        # 否则,选择当前状态下 Q 值最大的动作(利用)
        以概率 ε 选择随机动作 a
        否则选择 a = argmax_a Q(s, a; θ)

        # 在环境中执行动作 a
        # 观察奖励 r 和下一个状态 s'
        执行动作 a,观察奖励 r 和下一个状态 s'

        # 将当前经验 (s, a, r, s') 存储到经验回放缓冲区 D 中
        # 经验回放缓冲区用于保存过去的交互记录
        将转换 (s, a, r, s') 存储到 D

        # 从经验回放缓冲区中随机抽取一个批次(minibatch)用于训练
        # 随机抽样打破时间相关性,提高样本效率
        从 D 中随机抽取一批 (s, a, r, s')

        # 使用目标 Q 网络 θ^- 计算目标 Q 值
        # 根据 Bellman 方程更新:当前奖励加上下一个状态的最大折扣 Q 值
        计算目标值:
            y = r + γ * max_{a'} Q(s', a'; θ^-)

        # 使用目标值 y 和当前 Q 网络 θ 的预测值更新 Q 网络
        # 损失函数计算预测 Q 值与目标 Q 值之间的差距
        更新 Q 网络,最小化损失:
            L(θ) = (y - Q(s, a; θ))^2

        # 每隔 N 步将当前 Q 网络的参数 θ 更新到目标 Q 网络 θ^-
        # 目标网络更新可以稳定训练过程
        每 N 步,更新 θ^- = θ

        # 将下一个状态 s' 设置为当前状态 s
        s = s'

        # 如果当前状态是终止状态,则结束本轮 episode
        if s 是终止状态 then break
    end for
end for

DQN算法在gym环境中实现

1.库导入

import gym
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque

 2.定义Q网络

# Define the Q-Network
class QNetwork(nn.Module):
    def __init__(self, state_size, action_size):
        super(QNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_size, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, action_size)
        )

    def forward(self, x):
        return self.fc(x)

 3.定义智能体

# DQN Agent Implementation
class DQNAgent:
    def __init__(self, state_size, action_size, gamma=0.99, epsilon=1.0, epsilon_min=0.1, epsilon_decay=0.995, lr=0.001):
        self.state_size = state_size
        self.action_size = action_size
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.lr = lr

        self.q_network = QNetwork(state_size, action_size)
        self.target_network = QNetwork(state_size, action_size)
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

        self.replay_buffer = deque(maxlen=10000)

    def act(self, state):
        if random.random() < self.epsilon:
            return random.choice(range(self.action_size))
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            q_values = self.q_network(state_tensor)
        return torch.argmax(q_values).item()

    def remember(self, state, action, reward, next_state, done):
        self.replay_buffer.append((state, action, reward, next_state, done))

    def replay(self, batch_size):
        if len(self.replay_buffer) < batch_size:
            return
        batch = random.sample(self.replay_buffer, batch_size)

        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions).unsqueeze(1)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)

        # Compute target Q values
        with torch.no_grad():
            next_q_values = self.target_network(next_states).max(1)[0]
            target_q_values = rewards + self.gamma * next_q_values * (1 - dones)

        # Compute current Q values
        q_values = self.q_network(states).gather(1, actions).squeeze()

        # Update Q-network
        loss = self.criterion(q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        self.target_network.load_state_dict(self.q_network.state_dict())

    def decay_epsilon(self):
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

 4.训练代码

# Train DQN in a Gym Environment
def train_dqn(env_name, episodes=500, batch_size=64, target_update=10):
    env = gym.make(env_name)
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n

    agent = DQNAgent(state_size, action_size)
    rewards_history = []

    for episode in range(episodes):
        state = env.reset()
        total_reward = 0
        done = False

        while not done:
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward

            agent.replay(batch_size)

        rewards_history.append(total_reward)
        agent.decay_epsilon()

        if episode % target_update == 0:
            agent.update_target_network()

        print(f"Episode {episode + 1}/{episodes}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.2f}")

    env.close()
    return rewards_history

 5.主函数 

# Example usage
if __name__ == "__main__":
    rewards = train_dqn("CartPole-v1", episodes=500)

    # Plot training results
    import matplotlib.pyplot as plt
    plt.plot(rewards)
    plt.xlabel("Episode")
    plt.ylabel("Total Reward")
    plt.title("DQN Training on CartPole-v1")
    plt.show()

[Notice]  说明

  1. 核心组件:

    • QNetwork: 定义了一个简单的全连接神经网络,近似 ( Q(s, a) )。
    • DQNAgent: 实现了行为选择、经验存储、经验回放、目标网络更新等功能。
  2. 主要过程:

    • 每次选择动作时遵循 ( \epsilon )-贪心策略,结合探索与利用。
    • 使用经验回放提升训练效率,通过随机采样打破时间相关性。
    • 定期更新目标网络,稳定训练过程。
  3. 环境:

    • 使用 Gym 提供的 CartPole-v1 环境作为测试场景。
  4. 结果:

    1. 训练曲线显示随着训练的进行,智能体逐渐学习到了稳定的策略,总奖励逐步增加。

        由于博文主要为了介绍相关算法的原理和应用的方法,缺乏对于实际效果的关注,算法可能在上述环境中的效果不佳或者无法运行,一是算法不适配上述环境,二是算法未调参和优化,三是没有呈现完整的代码,四是等等。上述代码用于了解和学习算法足够了,但若是想直接将上面代码应用于实际项目中,还需要进行修改。


4.重要改进

Double DQN

        解决DQN在估计目标值 \max_{a'} Q(s', a'; \theta^-)时可能存在的过高偏差:

y = r + \gamma Q(s', \arg\max_{a'} Q(s', a'; \theta); \theta^-)

Dueling DQN

        引入状态价值函数 ( V(s) )和优势函数( A(s, a) ),分解Q值:

[ Q(s, a) = V(s) + A(s, a) ]

Prioritized Experience Replay

        通过为经验回放分配优先级,提高样本效率。


5.DQN的强化学习背景应用

  • 游戏AI: Atari游戏、围棋、象棋等智能体。
  • 机器人控制: 在动态环境中学习复杂行为。
  • 资源调度: 云计算任务调度、边缘计算优化。
  • 交通管理: 自主驾驶、智能交通信号优化。

参考文献:Playing Atari with Deep Reinforcement Learning

 更多自监督强化学习文章,请前往:【强化学习(RL)】专栏


        博客都是给自己看的笔记,如有误导深表抱歉。文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者添加VX:Rainbook_2,联系作者。✨


http://www.kler.cn/a/413930.html

相关文章:

  • 读《Effective Java》笔记 - 条目11
  • 网易博客旧文-----安卓界面代码例子研究(三)
  • 网络技术-服务链编排的介绍和与虚拟化的区别
  • Oracle 执行计划查看方法汇总及优劣对比
  • Elasticsearch与CCS跨集群搜索:深入讲解与实战演练
  • v-for产生 You may have an infinite update loop in a component render function
  • Linux笔记4 磁盘管理
  • 29 基于51单片机的汽车倒车防撞报警器系统
  • 最新保姆级Linux下安装与使用conda:从下载配置到使用全流程
  • 深入解析 PyTorch 的 torch.load() 函数:用法、参数与实际应用示例
  • TypeScript基础语法总结
  • C++:探索哈希表秘密之哈希桶实现哈希
  • 遗传算法与深度学习实战——进化优化的局限性
  • 基于Linux的citus搭建标准化
  • day2全局注册
  • Oracle 19c RAC单节点停机维护硬件
  • 防止按钮被频繁点击
  • 分布式MQTT代理中使用布隆过滤器管理通配符主题
  • Python Turtle召唤童年:小黄人绘画
  • 微服务保护和分布式事务
  • 微信小程序按字母顺序渲染城市 功能实现详细讲解
  • 2024年9月GESPC++一级真题解析
  • springboot配置多数据源mysql+TDengine保姆级教程
  • 探索文件系统,Python os库是你的瑞士军刀
  • C++设计模式——Abstract Factory Pattern抽象工厂模式
  • 介绍一下printf,scanf