《Python实战进阶》No. 37: Reinforcement Learning Basics: Q-Learning and DQN - Bonus 1: Visualizing the Q-Learning Algorithm
In 《Python实战进阶》No. 37: Reinforcement Learning Basics: Q-Learning and DQN, we walked through the code for a Q-Learning agent that learns to escape a maze. In this bonus article, we visualize that training process. I first tried Matplotlib, but its animation support is awkward for this kind of real-time rendering and the result felt stiff, so I redesigned the Q-Learning visualization with pygame, which makes the animation noticeably smoother and more interactive. Compared with matplotlib, pygame is better suited to real-time animation and game-like content. A complete pygame-based implementation follows.
Video: visualizing Q-Learning training
Goals
- Maze layout: draw the maze dynamically (including the start, the goal, and the walls).
- Agent movement: update the agent's position in real time.
- Optimal path: after training, show the optimal path from the start to the goal.
- Overall goal: present the complete training process of the Q-Learning algorithm.
Implementation Steps
Step 1: Install dependencies
Make sure the pygame library is installed:
pip install pygame
Step 2: Modify the maze environment
We extend the maze environment slightly so that it better supports pygame visualization.
import numpy as np

class MazeEnv:
    def __init__(self):
        self.maze = [
            ['.', '.', '.', '#', '.'],
            ['.', '#', '.', '.', '.'],
            ['.', '#', '.', '#', '.'],
            ['.', '.', '.', '#', '.'],
            ['.', '#', 'G', '#', '.']
        ]
        self.maze = np.array(self.maze)
        self.start = (0, 0)
        self.goal = (4, 2)
        self.current_state = self.start
        self.actions = [(0, 1), (0, -1), (1, 0), (-1, 0)]  # right, left, down, up

    def reset(self):
        self.current_state = self.start
        return self.current_state

    def step(self, action):
        next_state = (self.current_state[0] + action[0], self.current_state[1] + action[1])
        if (
            next_state[0] < 0 or next_state[0] >= self.maze.shape[0] or
            next_state[1] < 0 or next_state[1] >= self.maze.shape[1] or
            self.maze[next_state] == '#'
        ):
            next_state = self.current_state  # hitting a wall or the boundary: stay in place
        reward = -1  # default reward for every move
        done = False
        if next_state == self.goal:
            reward = 10  # reward for reaching the goal
            done = True
        self.current_state = next_state
        return next_state, reward, done

    def get_maze_size(self):
        return self.maze.shape

    def is_wall(self, position):
        return self.maze[position] == '#'

    def is_goal(self, position):
        return position == self.goal
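Before adding the visualization, a quick sanity check of the environment's interface can be helpful. This is a minimal sketch, not part of the final program; it only confirms what reset() and step() return for this maze:

env = MazeEnv()
print(env.reset())                            # (0, 0), the start cell
next_state, reward, done = env.step((0, 1))   # action (0, 1) means "move right"
print(next_state, reward, done)               # (0, 1) -1 False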
Step 3: Design the pygame visualization
The complete pygame-based visualization code is shown below:
import pygame
import time
import random
import numpy as np

# Initialize pygame
pygame.init()

# Colors
WHITE = (255, 255, 255)  # open cell
BLACK = (0, 0, 0)        # wall
GREEN = (0, 255, 0)      # goal
RED = (255, 0, 0)        # agent
BLUE = (0, 0, 255)       # optimal path

# Cell size and animation frame rate
CELL_SIZE = 50
FPS = 10

def visualize_with_pygame(env, agent, num_episodes=1000):
    rows, cols = env.get_maze_size()
    screen_width = cols * CELL_SIZE
    screen_height = rows * CELL_SIZE

    # Initialize the window
    screen = pygame.display.set_mode((screen_width, screen_height))
    pygame.display.set_caption("Q-Learning Maze Visualization")
    clock = pygame.time.Clock()

    def draw_maze():
        for i in range(rows):
            for j in range(cols):
                rect = pygame.Rect(j * CELL_SIZE, i * CELL_SIZE, CELL_SIZE, CELL_SIZE)
                if env.is_wall((i, j)):
                    pygame.draw.rect(screen, BLACK, rect)
                elif env.is_goal((i, j)):
                    pygame.draw.rect(screen, GREEN, rect)
                else:
                    pygame.draw.rect(screen, WHITE, rect)

    def draw_agent(position):
        x, y = position
        center = (y * CELL_SIZE + CELL_SIZE // 2, x * CELL_SIZE + CELL_SIZE // 2)
        pygame.draw.circle(screen, RED, center, CELL_SIZE // 3)

    def draw_path(path):
        for (x, y) in path:
            rect = pygame.Rect(y * CELL_SIZE, x * CELL_SIZE, CELL_SIZE, CELL_SIZE)
            pygame.draw.rect(screen, BLUE, rect)

    # Visualize the training process
    for episode in range(num_episodes):
        state = env.reset()
        done = False
        path = [state]
        while not done:
            # Handle quit events
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
                    return
            # Clear the screen and draw the maze
            screen.fill(WHITE)
            draw_maze()
            # Choose an action and update the Q-table
            action = agent.get_action(state)
            next_state, reward, done = env.step(action)
            agent.update_q_table(state, action, reward, next_state)
            state = next_state
            path.append(state)
            # Draw the agent
            draw_agent(state)
            # Refresh the display
            pygame.display.flip()
            clock.tick(FPS)
        if episode % 100 == 0:
            print(f"Episode {episode}: Training...")

    # Visualize the test run
    state = env.reset()
    done = False
    path = [state]
    while not done:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                return
        screen.fill(WHITE)
        draw_maze()
        action = agent.get_action(state)
        state, _, done = env.step(action)
        path.append(state)
        draw_agent(state)
        pygame.display.flip()
        clock.tick(FPS)

    # Show the final path
    screen.fill(WHITE)
    draw_maze()
    draw_path(path)
    pygame.display.flip()

    # Wait for the user to close the window
    running = True
    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
    pygame.quit()
Step 4: Integrate with the Q-Learning algorithm
Plug the pygame visualization function into the Q-Learning training and test loop.
# Uses the random and numpy imports from the visualization block above.
class QLearningAgent:
    def __init__(self, env, learning_rate=0.1, discount_factor=0.9, epsilon=0.1):
        self.env = env
        self.q_table = {}
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon

    def get_action(self, state):
        if random.uniform(0, 1) < self.epsilon:
            return random.choice(self.env.actions)  # explore
        else:
            q_values = [self.get_q_value(state, action) for action in self.env.actions]
            return self.env.actions[np.argmax(q_values)]  # exploit (greedy policy)

    def get_q_value(self, state, action):
        key = (state, action)
        return self.q_table.get(key, 0.0)

    def update_q_table(self, state, action, reward, next_state):
        old_q = self.get_q_value(state, action)
        max_next_q = max([self.get_q_value(next_state, a) for a in self.env.actions])
        new_q = old_q + self.learning_rate * (reward + self.discount_factor * max_next_q - old_q)
        self.q_table[(state, action)] = new_q
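For reference, update_q_table implements the standard Q-Learning update rule, with learning_rate as α and discount_factor as γ:

Q(s, a) ← Q(s, a) + α · [ r + γ · max_a' Q(s', a') - Q(s, a) ]

Because q_table is a plain dictionary that defaults to 0.0, every state-action pair starts with a Q-value of zero and only gets an entry once it has been updated.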
Step 5: Run the code
Create the maze environment and the agent, then run the training and test code.
# Create the environment and the agent
env = MazeEnv()
agent = QLearningAgent(env)

# Visualize training and testing with pygame
visualize_with_pygame(env, agent, num_episodes=1000)
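Rendering every step of all 1000 training episodes at FPS = 10 can take quite a while. An optional variant (just a sketch, reusing the MazeEnv, QLearningAgent, and visualize_with_pygame definitions above) is to train headlessly first and then render only a few episodes with the already-learned Q-table:

# Headless training: same update loop as the visualizer, just without drawing
env = MazeEnv()
agent = QLearningAgent(env)
for episode in range(2000):
    state = env.reset()
    done = False
    while not done:
        action = agent.get_action(state)
        next_state, reward, done = env.step(action)
        agent.update_q_table(state, action, reward, next_state)
        state = next_state

# Now render only a handful of episodes with the trained agent
visualize_with_pygame(env, agent, num_episodes=5)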
Results
- Smooth animation: pygame's efficient drawing makes the animation much smoother.
- Real-time updates: the agent's position and path update in real time, clearly showing the learning process.
- Interactivity: the user can stop the program at any time by closing the window.
Extensions
- Tune the animation speed: adjust FPS and clock.tick() to control how fast the animation runs.
- Add a heat map: use different colors to show how the Q-table changes (see the sketch after this list).
- Support larger mazes: scale the cell size (CELL_SIZE) to fit bigger mazes.
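For the heat-map idea, here is one minimal sketch. It assumes the env, agent, CELL_SIZE, and pygame setup defined earlier; the helper name draw_q_heatmap is made up for illustration. It shades each open cell by its best Q-value and is meant to be called right after draw_maze() inside the drawing loop:

def draw_q_heatmap(screen, env, agent):
    rows, cols = env.get_maze_size()
    # Collect the best Q-value for every open cell so the colors can be normalized
    best_q = {}
    for i in range(rows):
        for j in range(cols):
            if env.is_wall((i, j)) or env.is_goal((i, j)):
                continue
            best_q[(i, j)] = max(agent.get_q_value((i, j), a) for a in env.actions)
    if not best_q:
        return
    q_min, q_max = min(best_q.values()), max(best_q.values())
    span = (q_max - q_min) or 1.0
    for (i, j), q in best_q.items():
        t = (q - q_min) / span  # 0 = lowest Q-value, 1 = highest
        color = (int(255 * (1 - t)), int(255 * t), 0)  # red-to-green gradient
        rect = pygame.Rect(j * CELL_SIZE, i * CELL_SIZE, CELL_SIZE, CELL_SIZE)
        pygame.draw.rect(screen, color, rect)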
With the approach above, you can build an efficient and smooth Q-Learning visualization!