当前位置: 首页 > article >正文

强化学习代码实践1.DDQN:在CartPole游戏中实现 Double DQN

强化学习代码实践1.DDQN:在CartPole游戏中实现 Double DQN

      • 1. 导入依赖
      • 2. 定义 Q 网络
      • 3. 创建 Agent
      • 4. 训练过程
      • 5. 解释
      • 6. 调整超参数

CartPole 游戏中实现 Double DQN(DDQN)训练网络时,我们需要构建一个使用两个 Q 网络(一个用于选择动作,另一个用于更新目标)的方法。Double DQN 通过引入目标网络来减少 Q-learning 中过度估计的偏差。

下面是一个基于 PyTorch 的 Double DQN 实现:

1. 导入依赖

import random
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
from collections import deque

2. 定义 Q 网络

我们需要定义一个 Q 网络,用于计算 Q 值。这里使用简单的全连接网络。

class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, action_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

3. 创建 Agent

class DoubleDQNAgent:
    def __init__(self, state_dim, action_dim, gamma=0.99, epsilon=0.1, epsilon_decay=0.995, epsilon_min=0.01, lr=0.0005):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.lr = lr

        self.q_network = QNetwork(state_dim, action_dim)
        self.target_network = QNetwork(state_dim, action_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.lr)

        self.memory = deque(maxlen=10000)
        self.batch_size = 64

    def select_action(self, state):
        if random.random() < self.epsilon:
            return random.choice(range(self.action_dim))  # Explore
        else:
            state = torch.FloatTensor(state).unsqueeze(0)
            with torch.no_grad():
                q_values = self.q_network(state)
            return torch.argmax(q_values).item()  # Exploit

    def store_experience(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def sample_batch(self):
        return random.sample(self.memory, self.batch_size)

    def update_target_network(self):
        self.target_network.load_state_dict(self.q_network.state_dict())

    def train(self):
        if len(self.memory) < self.batch_size:
            return

        batch = self.sample_batch()
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)

        # Q values for current states
        q_values = self.q_network(states)
        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Next Q values using target network
        next_q_values = self.target_network(next_states)
        next_actions = self.q_network(next_states).argmax(1)
        next_q_values = next_q_values.gather(1, next_actions.unsqueeze(1)).squeeze(1)

        # Double DQN update
        target = rewards + (1 - dones) * self.gamma * next_q_values

        # Compute loss
        loss = nn.MSELoss()(q_values, target)

        # Optimize the Q-network
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

4. 训练过程

def train_cartpole():
    env = gym.make('CartPole-v1')
    agent = DoubleDQNAgent(state_dim=env.observation_space.shape[0], action_dim=env.action_space.n)

    episodes = 1000
    for episode in range(episodes):
        state, info = env.reset()
        done = False
        total_reward = 0

        while not done:
            action = agent.select_action(state)
            next_state, reward, done, truncated, info = env.step(action)
            agent.store_experience(state, action, reward, next_state, done)
            state = next_state

            agent.train()
            total_reward += reward

        agent.update_target_network()

        print(f"Episode {episode}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.4f}")

    env.close()

if __name__ == '__main__':
    train_cartpole()

5. 解释

  • QNetwork: 使用一个简单的 3 层全连接神经网络来近似 Q 函数。
  • DoubleDQNAgent:
    • select_action: 根据 ε-greedy 策略选择动作。
    • store_experience: 存储经验回放。
    • sample_batch: 从记忆中随机采样批次。
    • train: 更新 Q 网络的权重,使用 Double DQN 的目标计算方法。
    • update_target_network: 每一定步数更新目标网络。
  • 训练过程: 在每一回合中,代理与环境互动并更新 Q 网络,通过经验回放机制逐步学习。

6. 调整超参数

  • gamma: 折扣因子,控制未来奖励的影响。
  • epsilon: 初始的探索率,随着训练的进行逐渐减小。
  • lr: 学习率,控制权重更新的步伐。
  • batch_size: 每次更新时,从记忆库中采样的批量大小。

这个代码可以直接用于训练一个 CartPole 的 Double DQN 代理,逐步优化 Q 网络来完成游戏任务。如果你有更复杂的需求,像更深的网络结构或其他改进,可以在此基础上进一步扩展。


http://www.kler.cn/a/504044.html

相关文章:

  • ip属地是根据手机号还是位置
  • LabVIEW智能水肥一体灌溉控制系统
  • 数仓建模(五)选择数仓技术栈:Hive ClickHouse 其它
  • 目标检测中的Bounding Box(边界框)介绍:定义以及不同表示方式
  • Termora 一个开源的 SSH 跨平台客户端工具
  • 快速、可靠且高性价比的定制IP模式提升芯片设计公司竞争力
  • Linux第二课:LinuxC高级 学习记录day03
  • PHP Filesystem:深入解析与实战应用
  • 【机器学习】聚类评价指标之福尔克斯–马洛斯指数(Fowlkes–Mallows Index, FMI)
  • 说一说mongodb组合索引的匹配规则
  • 从github上,下载的android项目,从0-1进行编译运行-踩坑精力,如何进行部署
  • 65.在 Vue 3 中使用 OpenLayers 绘制带有箭头的线条
  • 伏羲1.0试用版(文生图)
  • 【软件工程】知识点总结(下)
  • 基于python的舆情监测管理系统
  • phpstorm jetbrain 配置review code
  • React 中事件机制详细介绍:概念与执行流程如何更好的理解
  • 软件测试 —— 自动化测试(Selenium)
  • element-ui dialog弹窗 设置点击空白处不关闭
  • 【Redis】初识Redis
  • 机器学习赋能的智能光子学器件系统研究与应用
  • Spring Boot 项目启动后自动加载系统配置的多种实现方式
  • 202305 青少年软件编程等级考试C/C++ 二级真题答案及解析(电子学会)
  • 本地服务器Docker搭建个人云音乐平台Splayer并实现远程访问告别烦人广告
  • mapbox进阶,添加绘图控件
  • NHANES数据挖掘|特征变量对死亡率预测的研究设计与分析