【RL Base】强化学习:信赖域策略优化(TRPO)算法
📢本篇文章是博主强化学习(RL)领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对相关等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在👉强化学习专栏:
【强化学习】(51)---《信赖域策略优化(TRPO)算法》
信赖域策略优化(TRPO)算法
目录
1.信赖域策略优化
2.什么是 TRPO?
3.TRPO 的公式解析
1. 策略优化目标
2. KL 散度约束
3. Fisher 信息矩阵
4. 二次约束优化
[Python] TRPO算法实现
1.TRPO算法伪代码
2.TRPO算法pytorch实现代码
[Notice] 注意事项
4.TRPO 的优势
5.TRPO 的应用场景
6.结论
1.信赖域策略优化
在强化学习(RL)领域,如何稳定地优化策略是一个核心挑战。2015 年,由 John Schulman 等人提出的信赖域策略优化(Trust Region Policy Optimization, TRPO)算法为这一问题提供了优雅的解决方案。TRPO 通过限制策略更新的幅度,避免了策略更新过大导致的不稳定问题,是强化学习中经典的策略优化方法之一。
2.什么是 TRPO?
TRPO 是一种基于策略梯度的优化算法,其目标是通过限制新策略和旧策略之间的差异来确保训练的稳定性。TRPO 在高维、连续动作空间中表现尤为出色,尤其适用于机器人控制、游戏 AI 等领域。
基本思想
TRPO 的核心目标是寻找新策略 ,使得累积奖励最大化,同时限制新旧策略之间的变化幅度。数学形式化如下:
其中:
- : 旧策略。
- : 新策略。
- : 优势函数,衡量动作 相对于状态的优劣。
TRPO 在优化过程中通过限制新旧策略之间的 KL 散度来实现稳定更新:
3.TRPO 的公式解析
1. 策略优化目标
TRPO 的优化目标是:
其中:
- : 概率比率,衡量新旧策略的变化。
- : 优势函数,用于评价动作的相对收益。
2. KL 散度约束
通过添加 KL 散度约束,TRPO 控制新策略与旧策略的差异:
- : 衡量新旧策略概率分布的差异。
- : 预设的信赖域阈值。
3. Fisher 信息矩阵
KL 散度的二阶导数可以通过 Fisher 信息矩阵近似表示:
4. 二次约束优化
TRPO 的优化过程最终转化为一个带二次约束的优化问题:
- : 策略梯度。
- : Fisher 信息矩阵,用于表示 KL 散度的变化。
通过拉格朗日乘数法,TRPO 可以高效求解上述问题。
[Python] TRPO算法实现
1.TRPO算法伪代码
"""《TRPO 的实现流程》
时间:2024.11
作者:不去幼儿园
"""
初始化策略参数 θ 和目标网络
for 迭代轮次 do
1. 使用策略 π_θ 采样轨迹 (s, a, r)
2. 计算优势函数 A(s, a) (可以用 GAE 方法)
3. 计算策略梯度 g = ∇_θ L(θ)
4. 估计 Fisher 信息矩阵 H
5. 通过优化问题 max δθ^T g, subject to δθ^T H δθ ≤ δ 更新策略
6. 更新目标网络
end for
2.TRPO算法pytorch实现代码
import argparse
from itertools import count
import gym
import scipy.optimize
import torch
from models import *
from replay_memory import Memory
from running_state import ZFilter
from torch.autograd import Variable
from trpo import trpo_step
from utils import *
torch.utils.backcompat.broadcast_warning.enabled = True
torch.utils.backcompat.keepdim_warning.enabled = True
torch.set_default_tensor_type('torch.DoubleTensor')
parser = argparse.ArgumentParser(description='PyTorch actor-critic example')
parser.add_argument('--gamma', type=float, default=0.995, metavar='G',
help='discount factor (default: 0.995)')
parser.add_argument('--env-name', default="Reacher-v1", metavar='G',
help='name of the environment to run')
parser.add_argument('--tau', type=float, default=0.97, metavar='G',
help='gae (default: 0.97)')
parser.add_argument('--l2-reg', type=float, default=1e-3, metavar='G',
help='l2 regularization regression (default: 1e-3)')
parser.add_argument('--max-kl', type=float, default=1e-2, metavar='G',
help='max kl value (default: 1e-2)')
parser.add_argument('--damping', type=float, default=1e-1, metavar='G',
help='damping (default: 1e-1)')
parser.add_argument('--seed', type=int, default=543, metavar='N',
help='random seed (default: 1)')
parser.add_argument('--batch-size', type=int, default=15000, metavar='N',
help='random seed (default: 1)')
parser.add_argument('--render', action='store_true',
help='render the environment')
parser.add_argument('--log-interval', type=int, default=1, metavar='N',
help='interval between training status logs (default: 10)')
args = parser.parse_args()
env = gym.make(args.env_name)
num_inputs = env.observation_space.shape[0]
num_actions = env.action_space.shape[0]
env.seed(args.seed)
torch.manual_seed(args.seed)
policy_net = Policy(num_inputs, num_actions)
value_net = Value(num_inputs)
def select_action(state):
state = torch.from_numpy(state).unsqueeze(0)
action_mean, _, action_std = policy_net(Variable(state))
action = torch.normal(action_mean, action_std)
return action
def update_params(batch):
rewards = torch.Tensor(batch.reward)
masks = torch.Tensor(batch.mask)
actions = torch.Tensor(np.concatenate(batch.action, 0))
states = torch.Tensor(batch.state)
values = value_net(Variable(states))
returns = torch.Tensor(actions.size(0),1)
deltas = torch.Tensor(actions.size(0),1)
advantages = torch.Tensor(actions.size(0),1)
prev_return = 0
prev_value = 0
prev_advantage = 0
for i in reversed(range(rewards.size(0))):
returns[i] = rewards[i] + args.gamma * prev_return * masks[i]
deltas[i] = rewards[i] + args.gamma * prev_value * masks[i] - values.data[i]
advantages[i] = deltas[i] + args.gamma * args.tau * prev_advantage * masks[i]
prev_return = returns[i, 0]
prev_value = values.data[i, 0]
prev_advantage = advantages[i, 0]
targets = Variable(returns)
# Original code uses the same LBFGS to optimize the value loss
def get_value_loss(flat_params):
set_flat_params_to(value_net, torch.Tensor(flat_params))
for param in value_net.parameters():
if param.grad is not None:
param.grad.data.fill_(0)
values_ = value_net(Variable(states))
value_loss = (values_ - targets).pow(2).mean()
# weight decay
for param in value_net.parameters():
value_loss += param.pow(2).sum() * args.l2_reg
value_loss.backward()
return (value_loss.data.double().numpy(), get_flat_grad_from(value_net).data.double().numpy())
flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(get_value_loss, get_flat_params_from(value_net).double().numpy(), maxiter=25)
set_flat_params_to(value_net, torch.Tensor(flat_params))
advantages = (advantages - advantages.mean()) / advantages.std()
action_means, action_log_stds, action_stds = policy_net(Variable(states))
fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data.clone()
def get_loss(volatile=False):
if volatile:
with torch.no_grad():
action_means, action_log_stds, action_stds = policy_net(Variable(states))
else:
action_means, action_log_stds, action_stds = policy_net(Variable(states))
log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds)
action_loss = -Variable(advantages) * torch.exp(log_prob - Variable(fixed_log_prob))
return action_loss.mean()
def get_kl():
mean1, log_std1, std1 = policy_net(Variable(states))
mean0 = Variable(mean1.data)
log_std0 = Variable(log_std1.data)
std0 = Variable(std1.data)
kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5
return kl.sum(1, keepdim=True)
trpo_step(policy_net, get_loss, get_kl, args.max_kl, args.damping)
running_state = ZFilter((num_inputs,), clip=5)
running_reward = ZFilter((1,), demean=False, clip=10)
for i_episode in count(1):
memory = Memory()
num_steps = 0
reward_batch = 0
num_episodes = 0
while num_steps < args.batch_size:
state = env.reset()
state = running_state(state)
reward_sum = 0
for t in range(10000): # Don't infinite loop while learning
action = select_action(state)
action = action.data[0].numpy()
next_state, reward, done, _ = env.step(action)
reward_sum += reward
next_state = running_state(next_state)
mask = 1
if done:
mask = 0
memory.push(state, np.array([action]), mask, next_state, reward)
if args.render:
env.render()
if done:
break
state = next_state
num_steps += (t-1)
num_episodes += 1
reward_batch += reward_sum
reward_batch /= num_episodes
batch = memory.sample()
update_params(batch)
if i_episode % args.log_interval == 0:
print('Episode {}\tLast reward: {}\tAverage reward {:.2f}'.format(
i_episode, reward_sum, reward_batch))
import numpy as np
import torch
from torch.autograd import Variable
from utils import *
def conjugate_gradients(Avp, b, nsteps, residual_tol=1e-10):
x = torch.zeros(b.size())
r = b.clone()
p = b.clone()
rdotr = torch.dot(r, r)
for i in range(nsteps):
_Avp = Avp(p)
alpha = rdotr / torch.dot(p, _Avp)
x += alpha * p
r -= alpha * _Avp
new_rdotr = torch.dot(r, r)
betta = new_rdotr / rdotr
p = r + betta * p
rdotr = new_rdotr
if rdotr < residual_tol:
break
return x
def linesearch(model,
f,
x,
fullstep,
expected_improve_rate,
max_backtracks=10,
accept_ratio=.1):
fval = f(True).data
print("fval before", fval.item())
for (_n_backtracks, stepfrac) in enumerate(.5**np.arange(max_backtracks)):
xnew = x + stepfrac * fullstep
set_flat_params_to(model, xnew)
newfval = f(True).data
actual_improve = fval - newfval
expected_improve = expected_improve_rate * stepfrac
ratio = actual_improve / expected_improve
print("a/e/r", actual_improve.item(), expected_improve.item(), ratio.item())
if ratio.item() > accept_ratio and actual_improve.item() > 0:
print("fval after", newfval.item())
return True, xnew
return False, x
def trpo_step(model, get_loss, get_kl, max_kl, damping):
loss = get_loss()
grads = torch.autograd.grad(loss, model.parameters())
loss_grad = torch.cat([grad.view(-1) for grad in grads]).data
def Fvp(v):
kl = get_kl()
kl = kl.mean()
grads = torch.autograd.grad(kl, model.parameters(), create_graph=True)
flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])
kl_v = (flat_grad_kl * Variable(v)).sum()
grads = torch.autograd.grad(kl_v, model.parameters())
flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads]).data
return flat_grad_grad_kl + v * damping
stepdir = conjugate_gradients(Fvp, -loss_grad, 10)
shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0, keepdim=True)
lm = torch.sqrt(shs / max_kl)
fullstep = stepdir / lm[0]
neggdotstepdir = (-loss_grad * stepdir).sum(0, keepdim=True)
print(("lagrange multiplier:", lm[0], "grad_norm:", loss_grad.norm()))
prev_params = get_flat_params_from(model)
success, new_params = linesearch(model, get_loss, prev_params, fullstep,
neggdotstepdir / lm[0])
set_flat_params_to(model, new_params)
return loss
完成代码见附件
🔥若是下面代码复现困难或者有问题,欢迎评论区留言;需要以整个项目形式的代码,请在评论区留下您的邮箱📌,以便于及时分享给您(私信难以及时回复)。
[Notice] 注意事项
由于博文主要为了介绍相关算法的原理和应用的方法,缺乏对于实际效果的关注,算法可能在上述环境中的效果不佳或者无法运行,一是算法不适配上述环境,二是算法未调参和优化,三是没有呈现完整的代码,四是等等。上述代码用于了解和学习算法足够了,但若是想直接将上面代码应用于实际项目中,还需要进行修改。
4.TRPO 的优势
-
更新稳定性
TRPO 限制策略更新幅度,避免因更新过大导致的性能崩溃。 -
适用范围广泛
TRPO 在高维连续动作空间中表现出色,适用于机器人控制等复杂任务。 -
高效性
TRPO 在信赖域内寻找最优更新方向,平衡了稳定性和优化效率。
5.TRPO 的应用场景
1. 机器人控制
在动态环境中优化机器人动作,例如步态优化和抓取任务。
2. 游戏 AI
训练高维连续动作的游戏智能体,例如模拟物理环境中的策略优化。
3. 资源调度
优化任务分配和资源调度问题,例如云计算中的负载平衡。
6.结论
TRPO 是强化学习领域的一个重要里程碑,其通过信赖域限制优化策略更新的方式,在稳定性和性能提升方面提供了良好的平衡。尽管 TRPO 在实践中因其计算复杂性被更新的算法(如 PPO)部分取代,但其核心思想仍然是深度强化学习发展的重要基础。
参考文献:Trust Region Policy Optimization
更多自监督强化学习文章,请前往:【强化学习(RL)】专栏
博客都是给自己看的笔记,如有误导深表抱歉。文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者添加VX:Rainbook_2,联系作者。✨