当前位置: 首页 > article >正文

强化学习策略梯度算法实现文档(CartPole-v1)

1. 概述

本代码使用策略梯度方法(Policy Gradient)解决OpenAI Gym的CartPole-v1环境问题,包含以下核心组件:

  • 策略网络:神经网络输出动作概率分布

  • REINFORCE算法:带熵正则化的策略梯度方法

  • 训练监控:实时奖励跟踪与模型保存

  • 可视化:训练过程曲线与策略演示


2. 环境说明

python

复制

env = gym.make('CartPole-v1')
  • 状态空间:4维连续向量 [车位置, 车速, 杆角度, 杆角速度]

  • 动作空间:2个离散动作(左推/右推)

  • 奖励机制:每步存活奖励+1,最大步长500

  • 终止条件

    • 杆倾斜超过15度

    • 车移动超出±2.4单位

    • 连续存活500步(成功)


3. 策略网络架构

python

复制

class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )
  • 输入层:4个神经元(与环境状态维度一致)

  • 隐藏层:2个全连接层(128神经元),ReLU激活

  • 输出层:2个神经元(对应动作数),输出logits

  • 设计特点

    • 无最终softmax层(由Categorical分布自动处理)

    • 深度结构增强表征能力


4. 训练算法实现
4.1 核心参数
参数作用
gamma0.99未来奖励折扣因子
entropy_coef0.01熵正则化系数
lr1e-3Adam优化器学习率
max_norm0.5梯度裁剪阈值
4.2 训练流程

python

复制

def train(...):
    # 数据收集阶段
    while not done:
        prob_dist = Categorical(logits=policy_net(state_tensor))
        action = prob_dist.sample()
        # 存储log_prob, entropy, reward等

    # 回报计算
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)

    # 损失计算
    policy_loss = -log_prob * R - entropy_coef * entropy

    # 梯度更新
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(...)
    optimizer.step()
  1. 经验收集

    • 使用当前策略采样轨迹

    • 记录状态、动作、奖励、对数概率、熵值

  2. 回报计算

    • 折扣累计奖励:Rt=∑k=0Tγkrt+kRt​=∑k=0T​γkrt+k​

    • 标准化处理:R~t=(Rt−μR)/σRR~t​=(Rt​−μR​)/σR​

  3. 损失函数

    • 策略梯度损失:LPG=−E[log⁡π(a∣s)R~]LPG​=−E[logπ(a∣s)R~]

    • 熵正则项:Lent=−βH(π(⋅∣s))Lent​=−βH(π(⋅∣s))

    • 总损失:Ltotal=LPG+LentLtotal​=LPG​+Lent​

  4. 优化步骤

    • 梯度裁剪防止爆炸

    • Adam优化器更新参数


5. 关键技术点
5.1 熵正则化

python

复制

entropy = prob_dist.entropy()
policy_loss.append(... - entropy_coef * entropy)
  • 作用:增加探索,防止策略过早收敛

  • 效果:保持动作概率分布分散度

5.2 梯度裁剪

python

复制

torch.nn.utils.clip_grad_norm_(..., max_norm=0.5)
  • 原理:限制梯度L2范数不超过阈值

  • 优势:提升训练稳定性

5.3 状态标准化

python

复制

returns = (returns - returns.mean()) / (...)
  • 目的:减少回报方差

  • 注意:保留少量常数(1e-8)防止除零错误


6. 训练监控与评估
6.1 进度跟踪

python

复制

if (episode + 1) % 50 == 0:
    avg_reward = np.mean(episode_rewards[-50:])
    if avg_reward >= env.spec.reward_threshold:  # 默认阈值475
        print(f"Solved in {episode + 1} episodes!")
  • 输出频率:每50轮显示平均奖励

  • 停止条件:最近50轮平均奖励≥475

6.2 模型保存

python

复制

torch.save(policy_net.state_dict(), 'cartpole_policy.pth')
  • 格式:PyTorch模型参数

  • 用途:后续部署或继续训练

6.3 策略测试

python

复制

def test_policy(...):
    action = policy_net(...).argmax().item()
    env.render()
  • 策略选择:贪婪策略(取最大概率动作)

  • 渲染显示:可视化杆平衡过程


7. 可视化输出

python

复制

plt.plot(rewards)
plt.title('CartPole Training Progress')
  • X轴:训练轮次

  • Y轴:单轮总奖励

  • 典型曲线

    训练曲线示例


8. 运行与调优
8.1 执行命令

bash

复制

python cartpole_pg.py
8.2 预期输出

text

复制

Episode 50, Avg Reward (last 50): 42.3
Episode 100, Avg Reward (last 50): 195.2
Solved in 127 episodes!
8.3 调优建议
  • 学习率:尝试1e-4 ~ 3e-3范围

  • 网络结构:调整隐藏层维度(64-256)

  • 熵系数:0.001-0.1之间调节

  • 折扣因子:0.95-0.999


9. 扩展应用
  • 更换环境:适配MountainCar、LunarLander等离散动作环境

  • 算法改进

    • 添加基线(Baseline)减少方差

    • 实现PPO/TRPO等高级策略梯度方法

  • 分布式训练:使用多环境并行采样

此实现完整展示了策略梯度方法的核心思想,可作为强化学习基础实验平台。

完整代码

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical
import matplotlib.pyplot as plt  # 添加导入语句

# 定义策略网络(增加层数和激活函数)
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(PolicyNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )
        
    def forward(self, x):
        return self.fc(x)

# 改进的训练函数(修复梯度计算,添加熵正则化)
def train(env, policy_net, optimizer, num_episodes=1500, gamma=0.99, entropy_coef=0.01):
    episode_rewards = []
    
    for episode in range(num_episodes):
        state, _ = env.reset()  # 适配新版gym API
        states, actions, rewards, log_probs, entropies = [], [], [], [], []
        done = False
        
        # 收集轨迹数据
        while not done:
            state_tensor = torch.FloatTensor(state)
            logits = policy_net(state_tensor)
            prob_dist = Categorical(logits=logits)
            
            action = prob_dist.sample()
            log_prob = prob_dist.log_prob(action)
            entropy = prob_dist.entropy()
            
            # 执行动作(适配新版gym API)
            next_state, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            
            # 存储数据
            states.append(state_tensor)
            actions.append(action)
            rewards.append(reward)
            log_probs.append(log_prob)
            entropies.append(entropy)
            
            state = next_state
        
        # 计算折扣回报
        returns = []
        R = 0
        for r in reversed(rewards):
            R = r + gamma * R
            returns.insert(0, R)
        returns = torch.tensor(returns)
        
        # 标准化回报
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
        
        # 计算损失
        policy_loss = []
        for log_prob, R, entropy in zip(log_probs, returns, entropies):
            policy_loss.append(-log_prob * R - entropy_coef * entropy)
        
        total_loss = torch.stack(policy_loss).sum()
        
        # 反向传播
        optimizer.zero_grad()
        total_loss.backward()
        
        # 梯度裁剪(防止梯度爆炸)
        torch.nn.utils.clip_grad_norm_(policy_net.parameters(), max_norm=0.5)
        optimizer.step()
        
        # 记录训练进度
        total_reward = sum(rewards)
        episode_rewards.append(total_reward)
        
        # 显示训练进度
        if (episode + 1) % 50 == 0:
            avg_reward = np.mean(episode_rewards[-50:])
            print(f"Episode {episode + 1}, Avg Reward (last 50): {avg_reward:.1f}")
            if avg_reward >= env.spec.reward_threshold:
                print(f"Solved in {episode + 1} episodes!")
                break
    
    return episode_rewards

# 主函数(添加模型保存和测试功能)
if __name__ == "__main__":
    # 创建环境
    env = gym.make('CartPole-v1')
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    
    # 初始化网络和优化器
    policy_net = PolicyNetwork(state_dim, action_dim)
    optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)
    
    # 训练模型
    rewards = train(env, policy_net, optimizer, num_episodes=1000)
    
    # 保存模型
    torch.save(policy_net.state_dict(), 'cartpole_policy.pth')
    
    # 测试训练好的策略
    def test_policy(env, policy_net, episodes=10):
        for _ in range(episodes):
            state, _ = env.reset()
            done = False
            while not done:
                with torch.no_grad():
                    action = policy_net(torch.FloatTensor(state)).argmax().item()
                state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                env.render()
            print(f"Test episode finished")
        env.close()
    
    test_policy(env, policy_net)
    
    # 绘制训练曲线
    plt.plot(rewards)
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.title('CartPole Training Progress')
    plt.show()


http://www.kler.cn/a/569890.html

相关文章:

  • barcodelib:一个功能强大且易于使用的 C# 条形码生成库
  • 2025全开源Java多语言跨境电商外贸商城/Tk/FB内嵌商城I商家入驻I批量下单I完美运行
  • 【QT网络问题】关于QT在调用天气等类似api接口时报错
  • 差旅费控平台作用、功能、11款主流产品优劣势对比
  • Docker 数据卷管理及优化
  • 【网络安全 | 渗透测试】GraphQL精讲二:发现API漏洞
  • CAN总线通信协议学习4——数据链路层之仲裁规则
  • 【大模型原理与技术】1.2基于学习的语言模型
  • Yocto + 树莓派摄像头驱动完整指南
  • 如何为Java面试准备项目经验
  • 【告别双日期面板!一招实现el-date-picker智能联动日期选择】
  • 初探Ollama与deepseek
  • 【GESP】C++二级真题 luogu-B4037 [GESP202409 二级] 小杨的 N 字矩阵
  • 【无人机】无人机通信模块,无人机图数传模块的介绍,数传,图传,图传数传一体电台,
  • Windows Docker玩转Nginx,从零配置到自定义欢迎页
  • 三元组排序(acwing)c++
  • 关于后端使用Boolean或boolean时前端收到的参数的区别
  • Spring(历史)
  • 基于STM32的智能家居蓝牙系统(论文+源码)
  • Vue 表单优化:下拉框值改变前的确认提示与还原逻辑实现