当前位置: 首页 > article >正文

【RL Base】强化学习:信赖域策略优化(TRPO)算法

        📢本篇文章是博主强化学习(RL)领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对相关等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在👉强化学习专栏:

       【强化学习】(51)---《信赖域策略优化(TRPO)算法》

信赖域策略优化(TRPO)算法

目录

1.信赖域策略优化

2.什么是 TRPO?

3.TRPO 的公式解析

1. 策略优化目标

2. KL 散度约束

3. Fisher 信息矩阵

4. 二次约束优化

[Python] TRPO算法实现

1.TRPO算法伪代码

2.TRPO算法pytorch实现代码

[Notice]  注意事项

4.TRPO 的优势

5.TRPO 的应用场景

6.结论


1.信赖域策略优化

        在强化学习(RL)领域,如何稳定地优化策略是一个核心挑战。2015 年,由 John Schulman 等人提出的信赖域策略优化(Trust Region Policy Optimization, TRPO)算法为这一问题提供了优雅的解决方案。TRPO 通过限制策略更新的幅度,避免了策略更新过大导致的不稳定问题,是强化学习中经典的策略优化方法之一。


2.什么是 TRPO?

        TRPO 是一种基于策略梯度的优化算法,其目标是通过限制新策略和旧策略之间的差异来确保训练的稳定性。TRPO 在高维、连续动作空间中表现尤为出色,尤其适用于机器人控制、游戏 AI 等领域。

基本思想

        TRPO 的核心目标是寻找新策略 ( \pi_\theta ),使得累积奖励最大化,同时限制新旧策略之间的变化幅度。数学形式化如下:

\max_\theta \mathbb{E}{s \sim \rho{\pi_\text{old}}, a \sim \pi_\text{old}} \left[ \frac{\pi_\theta(a|s)}{\pi_\text{old}(a|s)} A_{\pi_\text{old}}(s, a) \right]

其中:

  • ( \pi_\text{old} ): 旧策略。
  • ( \pi_\theta ): 新策略。
  • ( A_{\pi_\text{old}}(s, a) ): 优势函数,衡量动作( a ) 相对于状态( s )的优劣。

TRPO 在优化过程中通过限制新旧策略之间的 KL 散度来实现稳定更新:

[ \mathbb{E}{s \sim \rho{\pi_\text{old}}} \left[ \text{KL}(\pi_\text{old} | \pi_\theta) \right] \leq \delta ]


3.TRPO 的公式解析

1. 策略优化目标

TRPO 的优化目标是:

L(\pi_\theta) = \mathbb{E}{s, a \sim \pi\text{old}} \left[ \frac{\pi_\theta(a|s)}{\pi_\text{old}(a|s)} A_{\pi_\text{old}}(s, a) \right]

其中:

  • ( \frac{\pi_\theta(a|s)}{\pi_\text{old}(a|s)} ): 概率比率,衡量新旧策略的变化。
  • ( A_{\pi_\text{old}}(s, a) ): 优势函数,用于评价动作的相对收益。

2. KL 散度约束

        通过添加 KL 散度约束,TRPO 控制新策略( \pi_\theta )与旧策略( \pi_\text{old} )的差异:

\mathbb{E}{s \sim \rho{\pi_\text{old}}} \left[ \text{KL}(\pi_\text{old} | \pi_\theta) \right] \leq \delta

  • ( \text{KL}(\pi_\text{old} | \pi_\theta) ): 衡量新旧策略概率分布的差异。
  • ( \delta ): 预设的信赖域阈值。

3. Fisher 信息矩阵

        KL 散度的二阶导数可以通过 Fisher 信息矩阵近似表示:

H = \mathbb{E}{s \sim \rho{\pi_\text{old}}, a \sim \pi_\text{old}} \left[ \nabla_\theta \log \pi_\theta(a|s) \nabla_\theta \log \pi_\theta(a|s)^T \right]

4. 二次约束优化

        TRPO 的优化过程最终转化为一个带二次约束的优化问题:

\max_\theta g^T \delta\theta \quad \text{subject to } \delta\theta^T H \delta\theta \leq \delta

  • ( g = \nabla_\theta L(\pi_\theta) ): 策略梯度。
  • ( H ): Fisher 信息矩阵,用于表示 KL 散度的变化。

通过拉格朗日乘数法,TRPO 可以高效求解上述问题。


[Python] TRPO算法实现

1.TRPO算法伪代码

"""《TRPO 的实现流程》
    时间:2024.11
    作者:不去幼儿园
"""
初始化策略参数 θ 和目标网络
for 迭代轮次 do
    1. 使用策略 π_θ 采样轨迹 (s, a, r)
    2. 计算优势函数 A(s, a) (可以用 GAE 方法)
    3. 计算策略梯度 g = ∇_θ L(θ)
    4. 估计 Fisher 信息矩阵 H
    5. 通过优化问题 max δθ^T g, subject to δθ^T H δθ ≤ δ 更新策略
    6. 更新目标网络
end for

2.TRPO算法pytorch实现代码

import argparse
from itertools import count

import gym
import scipy.optimize

import torch
from models import *
from replay_memory import Memory
from running_state import ZFilter
from torch.autograd import Variable
from trpo import trpo_step
from utils import *

torch.utils.backcompat.broadcast_warning.enabled = True
torch.utils.backcompat.keepdim_warning.enabled = True

torch.set_default_tensor_type('torch.DoubleTensor')

parser = argparse.ArgumentParser(description='PyTorch actor-critic example')
parser.add_argument('--gamma', type=float, default=0.995, metavar='G',
                    help='discount factor (default: 0.995)')
parser.add_argument('--env-name', default="Reacher-v1", metavar='G',
                    help='name of the environment to run')
parser.add_argument('--tau', type=float, default=0.97, metavar='G',
                    help='gae (default: 0.97)')
parser.add_argument('--l2-reg', type=float, default=1e-3, metavar='G',
                    help='l2 regularization regression (default: 1e-3)')
parser.add_argument('--max-kl', type=float, default=1e-2, metavar='G',
                    help='max kl value (default: 1e-2)')
parser.add_argument('--damping', type=float, default=1e-1, metavar='G',
                    help='damping (default: 1e-1)')
parser.add_argument('--seed', type=int, default=543, metavar='N',
                    help='random seed (default: 1)')
parser.add_argument('--batch-size', type=int, default=15000, metavar='N',
                    help='random seed (default: 1)')
parser.add_argument('--render', action='store_true',
                    help='render the environment')
parser.add_argument('--log-interval', type=int, default=1, metavar='N',
                    help='interval between training status logs (default: 10)')
args = parser.parse_args()

env = gym.make(args.env_name)

num_inputs = env.observation_space.shape[0]
num_actions = env.action_space.shape[0]

env.seed(args.seed)
torch.manual_seed(args.seed)

policy_net = Policy(num_inputs, num_actions)
value_net = Value(num_inputs)

def select_action(state):
    state = torch.from_numpy(state).unsqueeze(0)
    action_mean, _, action_std = policy_net(Variable(state))
    action = torch.normal(action_mean, action_std)
    return action

def update_params(batch):
    rewards = torch.Tensor(batch.reward)
    masks = torch.Tensor(batch.mask)
    actions = torch.Tensor(np.concatenate(batch.action, 0))
    states = torch.Tensor(batch.state)
    values = value_net(Variable(states))

    returns = torch.Tensor(actions.size(0),1)
    deltas = torch.Tensor(actions.size(0),1)
    advantages = torch.Tensor(actions.size(0),1)

    prev_return = 0
    prev_value = 0
    prev_advantage = 0
    for i in reversed(range(rewards.size(0))):
        returns[i] = rewards[i] + args.gamma * prev_return * masks[i]
        deltas[i] = rewards[i] + args.gamma * prev_value * masks[i] - values.data[i]
        advantages[i] = deltas[i] + args.gamma * args.tau * prev_advantage * masks[i]

        prev_return = returns[i, 0]
        prev_value = values.data[i, 0]
        prev_advantage = advantages[i, 0]

    targets = Variable(returns)

    # Original code uses the same LBFGS to optimize the value loss
    def get_value_loss(flat_params):
        set_flat_params_to(value_net, torch.Tensor(flat_params))
        for param in value_net.parameters():
            if param.grad is not None:
                param.grad.data.fill_(0)

        values_ = value_net(Variable(states))

        value_loss = (values_ - targets).pow(2).mean()

        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * args.l2_reg
        value_loss.backward()
        return (value_loss.data.double().numpy(), get_flat_grad_from(value_net).data.double().numpy())

    flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(get_value_loss, get_flat_params_from(value_net).double().numpy(), maxiter=25)
    set_flat_params_to(value_net, torch.Tensor(flat_params))

    advantages = (advantages - advantages.mean()) / advantages.std()

    action_means, action_log_stds, action_stds = policy_net(Variable(states))
    fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data.clone()

    def get_loss(volatile=False):
        if volatile:
            with torch.no_grad():
                action_means, action_log_stds, action_stds = policy_net(Variable(states))
        else:
            action_means, action_log_stds, action_stds = policy_net(Variable(states))
                
        log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds)
        action_loss = -Variable(advantages) * torch.exp(log_prob - Variable(fixed_log_prob))
        return action_loss.mean()


    def get_kl():
        mean1, log_std1, std1 = policy_net(Variable(states))

        mean0 = Variable(mean1.data)
        log_std0 = Variable(log_std1.data)
        std0 = Variable(std1.data)
        kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5
        return kl.sum(1, keepdim=True)

    trpo_step(policy_net, get_loss, get_kl, args.max_kl, args.damping)

running_state = ZFilter((num_inputs,), clip=5)
running_reward = ZFilter((1,), demean=False, clip=10)

for i_episode in count(1):
    memory = Memory()

    num_steps = 0
    reward_batch = 0
    num_episodes = 0
    while num_steps < args.batch_size:
        state = env.reset()
        state = running_state(state)

        reward_sum = 0
        for t in range(10000): # Don't infinite loop while learning
            action = select_action(state)
            action = action.data[0].numpy()
            next_state, reward, done, _ = env.step(action)
            reward_sum += reward

            next_state = running_state(next_state)

            mask = 1
            if done:
                mask = 0

            memory.push(state, np.array([action]), mask, next_state, reward)

            if args.render:
                env.render()
            if done:
                break

            state = next_state
        num_steps += (t-1)
        num_episodes += 1
        reward_batch += reward_sum

    reward_batch /= num_episodes
    batch = memory.sample()
    update_params(batch)

    if i_episode % args.log_interval == 0:
        print('Episode {}\tLast reward: {}\tAverage reward {:.2f}'.format(
            i_episode, reward_sum, reward_batch))
import numpy as np

import torch
from torch.autograd import Variable
from utils import *


def conjugate_gradients(Avp, b, nsteps, residual_tol=1e-10):
    x = torch.zeros(b.size())
    r = b.clone()
    p = b.clone()
    rdotr = torch.dot(r, r)
    for i in range(nsteps):
        _Avp = Avp(p)
        alpha = rdotr / torch.dot(p, _Avp)
        x += alpha * p
        r -= alpha * _Avp
        new_rdotr = torch.dot(r, r)
        betta = new_rdotr / rdotr
        p = r + betta * p
        rdotr = new_rdotr
        if rdotr < residual_tol:
            break
    return x


def linesearch(model,
               f,
               x,
               fullstep,
               expected_improve_rate,
               max_backtracks=10,
               accept_ratio=.1):
    fval = f(True).data
    print("fval before", fval.item())
    for (_n_backtracks, stepfrac) in enumerate(.5**np.arange(max_backtracks)):
        xnew = x + stepfrac * fullstep
        set_flat_params_to(model, xnew)
        newfval = f(True).data
        actual_improve = fval - newfval
        expected_improve = expected_improve_rate * stepfrac
        ratio = actual_improve / expected_improve
        print("a/e/r", actual_improve.item(), expected_improve.item(), ratio.item())

        if ratio.item() > accept_ratio and actual_improve.item() > 0:
            print("fval after", newfval.item())
            return True, xnew
    return False, x


def trpo_step(model, get_loss, get_kl, max_kl, damping):
    loss = get_loss()
    grads = torch.autograd.grad(loss, model.parameters())
    loss_grad = torch.cat([grad.view(-1) for grad in grads]).data

    def Fvp(v):
        kl = get_kl()
        kl = kl.mean()

        grads = torch.autograd.grad(kl, model.parameters(), create_graph=True)
        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])

        kl_v = (flat_grad_kl * Variable(v)).sum()
        grads = torch.autograd.grad(kl_v, model.parameters())
        flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads]).data

        return flat_grad_grad_kl + v * damping

    stepdir = conjugate_gradients(Fvp, -loss_grad, 10)

    shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0, keepdim=True)

    lm = torch.sqrt(shs / max_kl)
    fullstep = stepdir / lm[0]

    neggdotstepdir = (-loss_grad * stepdir).sum(0, keepdim=True)
    print(("lagrange multiplier:", lm[0], "grad_norm:", loss_grad.norm()))

    prev_params = get_flat_params_from(model)
    success, new_params = linesearch(model, get_loss, prev_params, fullstep,
                                     neggdotstepdir / lm[0])
    set_flat_params_to(model, new_params)

    return loss

完成代码见附件 

     🔥若是下面代码复现困难或者有问题,欢迎评论区留言;需要以整个项目形式的代码,请在评论区留下您的邮箱📌,以便于及时分享给您(私信难以及时回复)。


[Notice]  注意事项

        由于博文主要为了介绍相关算法的原理和应用的方法,缺乏对于实际效果的关注,算法可能在上述环境中的效果不佳或者无法运行,一是算法不适配上述环境,二是算法未调参和优化,三是没有呈现完整的代码,四是等等。上述代码用于了解和学习算法足够了,但若是想直接将上面代码应用于实际项目中,还需要进行修改。


4.TRPO 的优势

  1. 更新稳定性
    TRPO 限制策略更新幅度,避免因更新过大导致的性能崩溃。

  2. 适用范围广泛
    TRPO 在高维连续动作空间中表现出色,适用于机器人控制等复杂任务。

  3. 高效性
    TRPO 在信赖域内寻找最优更新方向,平衡了稳定性和优化效率。


5.TRPO 的应用场景

1. 机器人控制

在动态环境中优化机器人动作,例如步态优化和抓取任务。

2. 游戏 AI

训练高维连续动作的游戏智能体,例如模拟物理环境中的策略优化。

3. 资源调度

优化任务分配和资源调度问题,例如云计算中的负载平衡。


6.结论

        TRPO 是强化学习领域的一个重要里程碑,其通过信赖域限制优化策略更新的方式,在稳定性和性能提升方面提供了良好的平衡。尽管 TRPO 在实践中因其计算复杂性被更新的算法(如 PPO)部分取代,但其核心思想仍然是深度强化学习发展的重要基础。

参考文献:Trust Region Policy Optimization

 更多自监督强化学习文章,请前往:【强化学习(RL)】专栏


        博客都是给自己看的笔记,如有误导深表抱歉。文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者添加VX:Rainbook_2,联系作者。✨


http://www.kler.cn/a/416874.html

相关文章:

  • 嵌入式QT学习第3天:UI设计器的简单使用
  • 说说Elasticsearch拼写纠错是如何实现的?
  • ipmitool使用详解(三)-解决各种dell、hp服务器无法ipmitool连接问题
  • linux centos nginx编译安装
  • 网络安全实践
  • ubuntu20配置mysql注意事项
  • python3 自动更新的缓存类
  • 多类别的大豆叶病识别模型复现
  • Flutter:页面滚动
  • 软件无线电(SDR)的架构及相关术语
  • 软通动力携子公司鸿湖万联、软通教育助阵首届鸿蒙生态大会成功举办
  • 数据结构——排序算法第一幕(插入排序:直接插入排序、希尔排序 选择排序:直接选择排序,堆排序)超详细!!!!
  • 40分钟学 Go 语言高并发:服务性能调优实战
  • nginx搭建直播推流服务
  • PHP和GD库如何根据像素绘制图形
  • 小车AI视觉交互--1.颜色追踪
  • 一个Python脚本
  • 网络安全开源组件
  • 用堆求解最小可用ID问题
  • C++ 之弦上舞:string 类与多样字符串操作的优雅旋律
  • 面向数字音视频的网络与操作系统技术研讨会 征稿通知
  • Qt 项目中同时使用 CMAKE_AUTOUIC 和 UiTools 的注意事项
  • 泷羽Sec-星河飞雪-BurpSuite之解码、日志、对比模块基础使用
  • 频繁发生Full GC的原因有哪些?如何避免发生Full GC
  • vue3创建
  • 使用PyQt5开发一个GUI程序的实例演示