强化学习代码实践1.DDQN:在CartPole游戏中实现 Double DQN
强化学习代码实践1.DDQN:在CartPole游戏中实现 Double DQN
- 1. 导入依赖
- 2. 定义 Q 网络
- 3. 创建 Agent
- 4. 训练过程
- 5. 解释
- 6. 调整超参数
在 CartPole
游戏中实现 Double DQN(DDQN)训练网络时,我们需要构建一个使用两个 Q 网络(一个用于选择动作,另一个用于更新目标)的方法。Double DQN 通过引入目标网络来减少 Q-learning
中过度估计的偏差。
下面是一个基于 PyTorch
的 Double DQN 实现:
1. 导入依赖
import random
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
from collections import deque
2. 定义 Q 网络
我们需要定义一个 Q 网络,用于计算 Q 值。这里使用简单的全连接网络。
class QNetwork(nn.Module):
def __init__(self, state_dim, action_dim):
super(QNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, 128)
self.fc2 = nn.Linear(128, 128)
self.fc3 = nn.Linear(128, action_dim)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
return self.fc3(x)
3. 创建 Agent
class DoubleDQNAgent:
def __init__(self, state_dim, action_dim, gamma=0.99, epsilon=0.1, epsilon_decay=0.995, epsilon_min=0.01, lr=0.0005):
self.state_dim = state_dim
self.action_dim = action_dim
self.gamma = gamma
self.epsilon = epsilon
self.epsilon_decay = epsilon_decay
self.epsilon_min = epsilon_min
self.lr = lr
self.q_network = QNetwork(state_dim, action_dim)
self.target_network = QNetwork(state_dim, action_dim)
self.target_network.load_state_dict(self.q_network.state_dict())
self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.lr)
self.memory = deque(maxlen=10000)
self.batch_size = 64
def select_action(self, state):
if random.random() < self.epsilon:
return random.choice(range(self.action_dim)) # Explore
else:
state = torch.FloatTensor(state).unsqueeze(0)
with torch.no_grad():
q_values = self.q_network(state)
return torch.argmax(q_values).item() # Exploit
def store_experience(self, state, action, reward, next_state, done):
self.memory.append((state, action, reward, next_state, done))
def sample_batch(self):
return random.sample(self.memory, self.batch_size)
def update_target_network(self):
self.target_network.load_state_dict(self.q_network.state_dict())
def train(self):
if len(self.memory) < self.batch_size:
return
batch = self.sample_batch()
states, actions, rewards, next_states, dones = zip(*batch)
states = torch.FloatTensor(states)
actions = torch.LongTensor(actions)
rewards = torch.FloatTensor(rewards)
next_states = torch.FloatTensor(next_states)
dones = torch.FloatTensor(dones)
# Q values for current states
q_values = self.q_network(states)
q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Next Q values using target network
next_q_values = self.target_network(next_states)
next_actions = self.q_network(next_states).argmax(1)
next_q_values = next_q_values.gather(1, next_actions.unsqueeze(1)).squeeze(1)
# Double DQN update
target = rewards + (1 - dones) * self.gamma * next_q_values
# Compute loss
loss = nn.MSELoss()(q_values, target)
# Optimize the Q-network
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
if self.epsilon > self.epsilon_min:
self.epsilon *= self.epsilon_decay
4. 训练过程
def train_cartpole():
env = gym.make('CartPole-v1')
agent = DoubleDQNAgent(state_dim=env.observation_space.shape[0], action_dim=env.action_space.n)
episodes = 1000
for episode in range(episodes):
state, info = env.reset()
done = False
total_reward = 0
while not done:
action = agent.select_action(state)
next_state, reward, done, truncated, info = env.step(action)
agent.store_experience(state, action, reward, next_state, done)
state = next_state
agent.train()
total_reward += reward
agent.update_target_network()
print(f"Episode {episode}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.4f}")
env.close()
if __name__ == '__main__':
train_cartpole()
5. 解释
- QNetwork: 使用一个简单的 3 层全连接神经网络来近似 Q 函数。
- DoubleDQNAgent:
select_action
: 根据 ε-greedy 策略选择动作。store_experience
: 存储经验回放。sample_batch
: 从记忆中随机采样批次。train
: 更新 Q 网络的权重,使用 Double DQN 的目标计算方法。update_target_network
: 每一定步数更新目标网络。
- 训练过程: 在每一回合中,代理与环境互动并更新 Q 网络,通过经验回放机制逐步学习。
6. 调整超参数
gamma
: 折扣因子,控制未来奖励的影响。epsilon
: 初始的探索率,随着训练的进行逐渐减小。lr
: 学习率,控制权重更新的步伐。batch_size
: 每次更新时,从记忆库中采样的批量大小。
这个代码可以直接用于训练一个 CartPole 的 Double DQN 代理,逐步优化 Q 网络来完成游戏任务。如果你有更复杂的需求,像更深的网络结构或其他改进,可以在此基础上进一步扩展。