强化学习寻宝游戏
代码
import numpy as np
import tkinter as tk
import time
import random
# Custom treasure-hunt grid-world environment.
class TreasureHuntEnv:
    """A 10x10 grid world.

    Cell codes: 0 = empty, 1 = agent, 2 = treasure, 3 = obstacle, 4 = trap.
    The agent starts at (0, 0) and must reach the treasure at the
    bottom-right corner while avoiding obstacles and traps.
    """

    def __init__(self):
        self.grid_size = 10  # side length of the square grid
        self.actions = ['up', 'down', 'left', 'right']  # discrete action space
        self.reset()

    def reset(self):
        """Reset the environment and return the initial state (agent position)."""
        self.grid = np.zeros((self.grid_size, self.grid_size), dtype=int)
        # Agent starts in the top-left corner.
        self.agent_pos = (0, 0)
        self.grid[self.agent_pos] = 1
        # Treasure sits in the bottom-right corner.
        self.treasure_pos = (self.grid_size - 1, self.grid_size - 1)
        self.grid[self.treasure_pos] = 2
        # Scatter up to 10 obstacles (draws that collide with the agent or
        # treasure are skipped, so the actual count may be lower; duplicates
        # may also overlap — this matches the original placement behavior).
        for _ in range(10):
            x = random.randint(0, self.grid_size - 1)
            y = random.randint(0, self.grid_size - 1)
            if (x, y) not in [self.agent_pos, self.treasure_pos]:
                self.grid[x, y] = 3
        # Scatter up to 5 traps on cells that are not obstacles.
        # Trap positions are remembered so the marker can be restored after
        # the agent steps off a trap (see step()).
        self.trap_positions = set()
        for _ in range(5):
            x = random.randint(0, self.grid_size - 1)
            y = random.randint(0, self.grid_size - 1)
            if (x, y) not in [self.agent_pos, self.treasure_pos] and self.grid[x, y] != 3:
                self.grid[x, y] = 4
                self.trap_positions.add((x, y))
        return self.agent_pos

    def step(self, action):
        """Apply `action` and return (new_state, reward, done)."""
        x, y = self.agent_pos
        if action == 'up':
            x -= 1
        elif action == 'down':
            x += 1
        elif action == 'left':
            y -= 1
        elif action == 'right':
            y += 1
        # Classify the attempted move.
        if not (0 <= x < self.grid_size and 0 <= y < self.grid_size):
            # Out of bounds: stay put, small penalty.
            x, y = self.agent_pos
            reward, done = -1, False
        elif self.grid[x, y] == 3:
            # Obstacle: stay put, small penalty.
            x, y = self.agent_pos
            reward, done = -1, False
        elif self.grid[x, y] == 4:
            # Trap: large penalty, episode continues.
            reward, done = -5, False
        elif self.grid[x, y] == 2:
            # Treasure found: reward and terminate the episode.
            reward, done = 10, True
        else:
            # Plain move onto empty ground: small step cost.
            reward, done = -0.1, False
        # Move the agent marker. BUGFIX: the original wrote 0 unconditionally,
        # which erased a trap once the agent stepped off it; restore the trap
        # marker so traps persist for the rest of the episode.
        self.grid[self.agent_pos] = 4 if self.agent_pos in self.trap_positions else 0
        self.agent_pos = (x, y)
        self.grid[self.agent_pos] = 1
        return (x, y), reward, done
# Tabular Q-learning agent.
class QLearningAgent:
    """Epsilon-greedy Q-learning over a discrete state/action space.

    Args:
        env: environment exposing `actions`, `reset()` and `step(action)`.
        alpha: learning rate.
        gamma: discount factor.
        epsilon: exploration probability for epsilon-greedy selection.
    """

    def __init__(self, env, alpha=0.1, gamma=0.9, epsilon=0.1):
        self.env = env
        self.alpha = alpha      # learning rate
        self.gamma = gamma      # discount factor
        self.epsilon = epsilon  # exploration rate
        self.q_table = {}       # maps (state, action) -> Q value

    def get_q_value(self, state, action):
        """Return the stored Q value for (state, action); 0.0 if unseen."""
        return self.q_table.get((state, action), 0.0)

    def choose_action(self, state):
        """Pick an action with the epsilon-greedy policy."""
        if np.random.rand() < self.epsilon:
            return np.random.choice(self.env.actions)  # explore
        q_values = [self.get_q_value(state, a) for a in self.env.actions]
        # Exploit: argmax picks the first action on ties.
        return self.env.actions[np.argmax(q_values)]

    def update_q_value(self, state, action, reward, next_state):
        """One-step Q-learning update: Q += alpha * (TD target - Q)."""
        old_q = self.get_q_value(state, action)
        max_next_q = max(self.get_q_value(next_state, a) for a in self.env.actions)
        self.q_table[(state, action)] = old_q + self.alpha * (
            reward + self.gamma * max_next_q - old_q
        )

    def train(self, episodes=1000, render=False, max_steps=10000):
        """Train the agent for a number of episodes.

        Args:
            episodes: number of training episodes.
            render: if True, call the module-level `render_grid` every step.
            max_steps: per-episode step cap. BUGFIX: random obstacle placement
                can wall off the treasure, in which case the original
                uncapped `while not done` loop never terminated.
        """
        for episode in range(episodes):
            state = self.env.reset()
            done = False
            total_reward = 0
            steps = 0
            while not done and steps < max_steps:
                action = self.choose_action(state)
                next_state, reward, done = self.env.step(action)
                self.update_q_value(state, action, reward, next_state)
                state = next_state
                total_reward += reward
                steps += 1
                if render:
                    render_grid(self.env.grid)
                    time.sleep(0.01)  # throttle rendering speed
            print(f"Episode: {episode + 1}, Total Reward: {total_reward}")
# Render the grid world on the module-level tkinter canvas.
def render_grid(grid):
    """Draw every cell of `grid` as a colored square and refresh the window.

    Reads the module-level `canvas` and `root` created in the main program.
    """
    canvas.delete("all")  # wipe the previous frame
    cell_size = 50        # pixel size of one cell
    # Cell value -> fill color; any other value renders as empty ground.
    fill_for = {1: "blue", 2: "gold", 3: "black", 4: "red"}
    rows, cols = grid.shape
    for row in range(rows):
        for col in range(cols):
            left = col * cell_size
            top = row * cell_size
            canvas.create_rectangle(
                left,
                top,
                left + cell_size,
                top + cell_size,
                fill=fill_for.get(grid[row, col], "white"),
            )
    root.update()  # force an immediate redraw
# Main program: build the GUI, train the agent with live rendering,
# then hand control to the tkinter event loop.
if __name__ == "__main__":
    # Initialize tkinter; `root` and `canvas` are module-level globals
    # that render_grid reads directly.
    root = tk.Tk()
    root.title("Treasure Hunt Environment")
    canvas = tk.Canvas(root, width=500, height=500)
    canvas.pack()
    # Initialize the environment and the agent.
    env = TreasureHuntEnv()
    agent = QLearningAgent(env)
    # Train the agent (100 episodes) while rendering each step.
    agent.train(episodes=100, render=True)
    # Run the tkinter main loop.
    root.mainloop()
代码说明
环境设计:
10x10 的网格世界,包含智能体、宝藏、障碍物和陷阱。
智能体需要避开障碍物和陷阱,找到宝藏。
奖励机制:
找到宝藏:+10
踩到陷阱:-5
撞到边界或障碍物:-1
每移动一步:-0.1
渲染:
使用 tkinter 动态渲染网格世界。
智能体为蓝色,宝藏为金色,障碍物为黑色,陷阱为红色。
训练与渲染:
在训练过程中,调用 render_grid 函数实时更新网格世界。
使用 time.sleep(0.01) 控制渲染速度。