当前位置: 首页 > article >正文

【强化学习】PPO算法代码详解

介绍

PPO(Proximal Policy Optimization,近端策略优化)是一种用于强化学习的策略优化算法,由OpenAI在2017年提出。PPO结合了策略梯度方法的优点和信任区域优化(Trust Region Optimization)的思想,旨在实现高效、稳定的策略优化。它已成为强化学习中最常用的算法之一,广泛应用于各种任务,如游戏、机器人控制和自然语言处理等。

PPO的核心目标是通过限制策略更新的幅度,确保每次更新后的策略不会与之前的策略偏离太远,从而避免训练过程中的不稳定性和崩溃。具体来说,PPO通过引入一个“剪裁”(clipping)机制,限制策略更新的幅度,使其在一个安全的范围内进行。

PPO基于策略梯度方法,其目标函数可以表示为: 

L^{CLIP}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) \cdot A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \cdot A_t \right) \right]

其中:r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}是新旧策略的概率比。 A_t是优势函数,表示当前动作相对于平均表现的优劣。  \epsilon 是一个超参数,用于控制剪裁的范围(通常取值为0.1到0.2)。 剪裁机制的作用是:当 r_t(\theta) 超出 [1-\epsilon, 1+\epsilon] 范围时,目标函数会被限制,从而避免过大的策略更新。

代码

1. 导入所需要的库

import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical

2. 定义设备

print("============================================================================================")
# 设置设备为 cpu 或 cuda
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.empty_cache()
    print("设备设置为 : " + str(torch.cuda.get_device_name(device)))
else:
    print("设备设置为 : cpu")
print("============================================================================================")

3. 经验回放缓冲区

# 经验回放缓冲区
class RolloutBuffer:
    def __init__(self):
        self.actions = []         # 存储动作
        self.states = []          # 存储状态
        self.logprobs = []        # 存储对数概率
        self.rewards = []         # 存储奖励
        self.state_values = []    # 存储状态值
        self.is_terminals = []    # 存储是否终止标记
    
    def clear(self):
        # 清空所有缓存数据
        del self.actions[:]
        del self.states[:]
        del self.logprobs[:]
        del self.rewards[:]
        del self.state_values[:]
        del self.is_terminals[:]

4. Actor-Critic 网络

# Actor-Critic 网络
class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, has_continuous_action_space, action_std_init):
        super(ActorCritic, self).__init__()
        self.has_continuous_action_space = has_continuous_action_space
        
        # 如果是连续动作空间,则初始化动作方差
        if has_continuous_action_space:
            self.action_dim = action_dim
            self.action_var = torch.full((action_dim,), action_std_init * action_std_init).to(device)
        
        # 定义 actor 网络
        if has_continuous_action_space:
            self.actor = nn.Sequential(
                            nn.Linear(state_dim, 64),
                            nn.Tanh(),
                            nn.Linear(64, 64),
                            nn.Tanh(),
                            nn.Linear(64, action_dim),
                            nn.Tanh()
                        )
        else:
            self.actor = nn.Sequential(
                            nn.Linear(state_dim, 64),
                            nn.Tanh(),
                            nn.Linear(64, 64),
                            nn.Tanh(),
                            nn.Linear(64, action_dim),
                            nn.Softmax(dim=-1)
                        )
        # 定义 critic 网络
        self.critic = nn.Sequential(
                        nn.Linear(state_dim, 64),
                        nn.Tanh(),
                        nn.Linear(64, 64),
                        nn.Tanh(),
                        nn.Linear(64, 1)
                    )
   
    # 设置动作标准差
    def set_action_std(self, new_action_std):
        # 如果是连续动作空间: 更新 self.action_var ,计算新的动作方差
        # 如果是离散动作空间: 打印警告信息,提示该方法不适用于离散动作空间
        if self.has_continuous_action_space:
            self.action_var = torch.full((self.action_dim,), new_action_std * new_action_std).to(device)
            # 创建一个形状为 (action_dim,) 的张量,并用 action_std_init * action_std_init 填充所有元素
        else:
            print("--------------------------------------------------------------------------------------------")
            print("警告:在离散动作空间策略上调用 ActorCritic::set_action_std()")
            print("--------------------------------------------------------------------------------------------")
    
    # forward 方法未实现,直接抛出 NotImplementedError 异常
    # ActorCritic 类的主要功能通过 act 和 evaluate 方法实现,而不是 forward
    def forward(self):
        raise NotImplementedError
    
    
    def act(self, state):
        # 根据当前状态选择动作并返回动作、动作对数概率和状态值
        if self.has_continuous_action_space:
            action_mean = self.actor(state) # 通过 Actor 网络计算动作的均值
            cov_mat = torch.diag(self.action_var).unsqueeze(dim=0) # 构建协方差矩阵,使用 torch.diag 将对角矩阵扩展为合适的形状
            dist = MultivariateNormal(action_mean, cov_mat) # 用于生成动作
        else:
            action_probs = self.actor(state) # 通过 Actor 网络计算动作的概率分布
            dist = Categorical(action_probs) # 用于生成动作

        action = dist.sample() # 从分布中采样一个动作
        action_logprob = dist.log_prob(action) # 计算动作的对数概率
        state_val = self.critic(state) # 通过 Critic 网络评估状态值

        # 返回动作、动作对数概率和状态值,并调用detach()方法断开计算图
        return action.detach(), action_logprob.detach(), state_val.detach()
    
    def evaluate(self, state, action):
        # 评估给定状态和动作下的动作对数概率、状态值和分布熵
        if self.has_continuous_action_space:
            action_mean = self.actor(state)
            
            action_var = self.action_var.expand_as(action_mean)
            cov_mat = torch.diag_embed(action_var).to(device)
            dist = MultivariateNormal(action_mean, cov_mat)
            
            # 针对单一动作环境进行调整
            if self.action_dim == 1:
                action = action.reshape(-1, self.action_dim)
        else:
            action_probs = self.actor(state)
            dist = Categorical(action_probs)
            
        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_values = self.critic(state)
        
        return action_logprobs, state_values, dist_entropy

​​​​​​​​​​​​​​为什么需要两个函数?

  • act 函数 :用于实际与环境交互,生成的动作需要与环境交互,因此不需要计算梯度。
  • evaluate 函数 :用于策略更新,需要计算梯度以优化网络参数。

5. PPO算法

# PPO 算法
class PPO:
    def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, has_continuous_action_space, action_std_init=0.6):
        # 初始化参数
        self.has_continuous_action_space = has_continuous_action_space
        if has_continuous_action_space:
            self.action_std = action_std_init
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.buffer = RolloutBuffer()
        # 初始化当前策略网络和优化器
        self.policy = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)
        self.optimizer = torch.optim.Adam([
                        {'params': self.policy.actor.parameters(), 'lr': lr_actor},
                        {'params': self.policy.critic.parameters(), 'lr': lr_critic}
                    ])
        # 初始化旧策略网络,并复制当前策略的参数
        self.policy_old = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        # 初始化损失函数
        self.MseLoss = nn.MSELoss()

    # 设置动作标准差
    def set_action_std(self, new_action_std):
        if self.has_continuous_action_space:
            self.action_std = new_action_std
            self.policy.set_action_std(new_action_std)
            self.policy_old.set_action_std(new_action_std)
        else:
            print("--------------------------------------------------------------------------------------------")
            print("警告:在离散动作空间策略上调用 PPO::set_action_std()")
            print("--------------------------------------------------------------------------------------------")

    # 衰减动作标准差
    def decay_action_std(self, action_std_decay_rate, min_action_std):
        print("--------------------------------------------------------------------------------------------")
        if self.has_continuous_action_space:
            self.action_std = self.action_std - action_std_decay_rate
            self.action_std = round(self.action_std, 4)
            if self.action_std <= min_action_std:
                self.action_std = min_action_std
                print("将 actor 输出的 action_std 设置为最小值 : ", self.action_std)
            else:
                print("将 actor 输出的 action_std 设置为 : ", self.action_std)
            self.set_action_std(self.action_std)
        else:
            print("警告:在离散动作空间策略上调用 PPO::decay_action_std()")
        print("--------------------------------------------------------------------------------------------")

    # 根据当前状态选择动作,并将数据存入缓冲区
    def select_action(self, state):
        if self.has_continuous_action_space:
            with torch.no_grad():
                state = torch.FloatTensor(state).to(device)
                action, action_logprob, state_val = self.policy_old.act(state)

            self.buffer.states.append(state)
            self.buffer.actions.append(action)
            self.buffer.logprobs.append(action_logprob)
            self.buffer.state_values.append(state_val)

            return action.detach().cpu().numpy().flatten()
        else:
            with torch.no_grad():
                state = torch.FloatTensor(state).to(device)
                action, action_logprob, state_val = self.policy_old.act(state)
            
            self.buffer.states.append(state)
            self.buffer.actions.append(action)
            self.buffer.logprobs.append(action_logprob)
            self.buffer.state_values.append(state_val)

            return action.item()

    # 更新策略
    def update(self):
        # 使用蒙特卡洛方法估计回报
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(self.buffer.rewards), reversed(self.buffer.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)
            
        # 对回报进行归一化处理
        rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)

        # 将列表转换为张量
        old_states = torch.squeeze(torch.stack(self.buffer.states, dim=0)).detach().to(device)
        old_actions = torch.squeeze(torch.stack(self.buffer.actions, dim=0)).detach().to(device)
        old_logprobs = torch.squeeze(torch.stack(self.buffer.logprobs, dim=0)).detach().to(device)
        old_state_values = torch.squeeze(torch.stack(self.buffer.state_values, dim=0)).detach().to(device)

        # 计算优势值
        advantages = rewards.detach() - old_state_values.detach()

        # 优化策略,进行 K 个 epoch 的训练
        for _ in range(self.K_epochs):
            # 评估旧策略下的动作和状态值
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)
            state_values = torch.squeeze(state_values)
            # 计算概率比率 (pi_theta / pi_theta_old)
            ratios = torch.exp(logprobs - old_logprobs.detach())
            # 计算代理损失
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            # PPO 剪切目标的最终损失
            loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, rewards) - 0.01 * dist_entropy
            # 反向传播并更新梯度
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()
            
        # 将当前策略的参数复制给旧策略
        self.policy_old.load_state_dict(self.policy.state_dict())
        # 清空缓冲区
        self.buffer.clear()
    
    def save(self, checkpoint_path):
        # 保存模型参数到指定路径
        torch.save(self.policy_old.state_dict(), checkpoint_path)
   
    def load(self, checkpoint_path):
        # 从指定路径加载模型参数
        self.policy_old.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
        self.policy.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))


http://www.kler.cn/a/587020.html

相关文章:

  • java八股文之消息中间件
  • 创业者认知、思辨、成长指南
  • HarmonyOS NEXT个人开发经验总结
  • Docker 镜像和容器相关命令总结
  • Linux第三次练习
  • Qt实现多线程
  • vscode python相对路径的问题
  • 3.6、数字签名
  • Ollama+OpenWebUI本地部署大模型
  • nvm安装node失败的处理方法
  • @RequestParam、@RequestBody、@PathVariable
  • DeepSeek技术解析:MoE架构实现与代码实战
  • 十种处理权重矩阵的方法及数学公式
  • Java注解对象克隆
  • 元音辅音字符串计数leetcode3305,3306
  • 自然语言秒转SQL—— 免费体验 OB Cloud Text2SQL 数据查询
  • 软件行业的“3.15问题”有哪些?如何防止?
  • C++ unordered_map unordered_set 模拟实现
  • Certbot实现SSL免费证书自动续签(CentOS 7版 + Docker部署的nginx)
  • 测试工程师指南:基于需求文档构建本地安全知识库的完整实战