
Implementing a Gated Recurrent Unit (GRU) in PyTorch


| Feature | GRU | LSTM |
| --- | --- | --- |
| Computational efficiency | Faster, fewer parameters | Slower, more parameters |
| Structural complexity | Two gates (update and reset) | Three gates (input, forget, output) |
| Long-range dependencies | Suited to medium-length dependencies | Better at very long-range dependencies |
| Training speed | Faster training, more stable gradients | Slower training, more memory |
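The "fewer parameters" row is easy to verify empirically. The sketch below counts the learnable parameters of a single-layer `nn.GRU` against an `nn.LSTM` of the same size (the layer sizes here are arbitrary, chosen only to match the example later in this post). The GRU comes out at exactly 3/4 of the LSTM's count, because it stacks three gate-style transformations where the LSTM stacks four.

```python
import torch.nn as nn

def n_params(module):
    """Total number of learnable parameters in a module."""
    return sum(p.numel() for p in module.parameters())

# Identical (arbitrary) sizes for both layers.
gru = nn.GRU(input_size=2, hidden_size=16, batch_first=True)
lstm = nn.LSTM(input_size=2, hidden_size=16, batch_first=True)

print(n_params(gru), n_params(lstm))  # 960 1280 — a 3:4 ratio
```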

Example:

import torch
import torch.nn as nn
import torch.optim as optim
import random
import matplotlib.pyplot as plt

# 🏁 Maze environment (5×5)
class MazeEnv:
    def __init__(self, size=5):
        self.size = size
        self.state = (0, 0)  # start position
        self.goal = (size-1, size-1)  # goal position
        self.actions = [(0,1), (0,-1), (1,0), (-1,0)]  # right, left, down, up
    
    def reset(self):
        self.state = (0, 0)  # back to the start
        return self.state

    def step(self, action):
        dx, dy = self.actions[action]
        x, y = self.state
        nx, ny = max(0, min(self.size-1, x+dx)), max(0, min(self.size-1, y+dy))
        
        reward = 1 if (nx, ny) == self.goal else -0.1
        done = (nx, ny) == self.goal
        
        self.state = (nx, ny)
        return (nx, ny), reward, done

# 🤖 GRU policy network
class GRUPolicy(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(GRUPolicy, self).__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        out, hidden = self.gru(x, hidden)
        out = self.fc(out[:, -1, :])  # use only the last time step
        return out, hidden

# 🎯 Training setup
env = MazeEnv(size=5)
policy = GRUPolicy(input_size=2, hidden_size=16, output_size=4)
optimizer = optim.Adam(policy.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# 🎓 Training
num_episodes = 500
epsilon = 1.0  # initial ε, the probability of exploring
epsilon_min = 0.01  # lower bound on ε
epsilon_decay = 0.995  # per-episode ε decay factor
best_path = []  # stores the most recent successful path

for episode in range(num_episodes):
    state = env.reset()
    hidden = torch.zeros(1, 1, 16)  # initial GRU hidden state: (num_layers, batch, hidden_size)
    states, actions, rewards = [], [], []
    logits_list = []  

    for _ in range(20):  # at most 20 steps per episode
        state_tensor = torch.tensor([[state[0], state[1]]], dtype=torch.float32).unsqueeze(0)
        logits, hidden = policy(state_tensor, hidden)
        logits_list.append(logits)

        # ε-greedy action selection
        if random.random() < epsilon:
            action = random.choice(range(4))  # explore: random action
        else:
            action = torch.argmax(logits, dim=1).item()  # exploit: greedy action

        next_state, reward, done = env.step(action)

        states.append(state)
        actions.append(action)
        rewards.append(reward)

        if done:
            print(f"Episode {episode} - Reached Goal!")
            # record this episode's path; later successes overwrite it
            best_path = states + [next_state]
            break
        state = next_state

    # Compute the loss. Note: cross-entropy against the actions actually taken
    # reinforces every chosen action regardless of reward — a simplification,
    # not a proper policy-gradient update.
    logits = torch.cat(logits_list, dim=0)  # (T, 4)
    action_tensor = torch.tensor(actions, dtype=torch.long)  # (T,)
    loss = loss_fn(logits, action_tensor)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # decay ε (0.995**500 ≈ 0.08, so some exploration remains at the end)
    epsilon = max(epsilon_min, epsilon * epsilon_decay)

    if episode % 100 == 0:
        print(f"Episode {episode}, Loss: {loss.item():.4f}, Epsilon: {epsilon:.4f}")

# 🧐 Make sure best_path was actually recorded
if len(best_path) == 0:
    print("No path found during training.")
else:
    print(f"Best path: {best_path}")

# 🚀 Plot the recorded path
fig, ax = plt.subplots(figsize=(6,6))

# initialize the maze grid
maze = [[0 for _ in range(5)] for _ in range(5)]  # 5×5 maze
ax.imshow(maze, cmap="coolwarm", origin="upper")

# draw grid lines
ax.set_xticks(range(5))
ax.set_yticks(range(5))
ax.grid(True, color="black", linewidth=0.5)

# shade the recorded path in red
for (x, y) in best_path:
    ax.add_patch(plt.Rectangle((y, x), 1, 1, color="red", alpha=0.8))

# mark the start (S) and goal (G)
ax.text(0, 0, "S", ha="center", va="center", fontsize=14, color="white", fontweight="bold")
ax.text(4, 4, "G", ha="center", va="center", fontsize=14, color="white", fontweight="bold")

plt.title("GRU RL Agent - Best Path")
plt.show()
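As noted in the code, the script trains with plain cross-entropy on the actions it happened to take, which pushes the policy toward repeating them whether or not they earned reward. A standard alternative is a REINFORCE-style update that weights each action's log-probability by its discounted return. The sketch below is self-contained and uses made-up episode data (`logits`, `actions`, `rewards` are placeholders, not values from the maze run); it is only meant to show the shape of that loss, not a drop-in replacement for the training loop above.

```python
import torch
import torch.nn.functional as F

# Hypothetical episode of T = 5 steps (placeholder data, not from the maze run).
logits = torch.randn(5, 4, requires_grad=True)  # (T, num_actions)
actions = torch.tensor([0, 2, 2, 3, 0])         # action taken at each step
rewards = [-0.1, -0.1, -0.1, -0.1, 1.0]         # per-step rewards

# Discounted returns, computed backwards: G_t = r_t + gamma * G_{t+1}
gamma = 0.99
returns, G = [], 0.0
for r in reversed(rewards):
    G = r + gamma * G
    returns.insert(0, G)
returns = torch.tensor(returns)

# REINFORCE loss: -mean_t( G_t * log pi(a_t | s_t) )
log_probs = F.log_softmax(logits, dim=1)
chosen = log_probs[torch.arange(len(actions)), actions]
loss = -(returns * chosen).mean()
loss.backward()  # gradients now scale each action by its return
```

With this loss, steps that lead to the goal (positive return) are reinforced and wasted steps (negative return) are discouraged, which is the behavior the ε-greedy exploration above is ultimately trying to discover.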

