【强化学习】Stable-Baselines3学习笔记
【强化学习】Stable-Baselines3学习笔记
- Stable-Baselines3是什么
- 安装
- Example
- Reinforcement Learning Tips and Tricks
- VecEnv相关
Stable-Baselines3是什么
- Stable Baselines3(简称SB3)是一套基于PyTorch实现的强化学习算法的可靠工具集
- 旨在为研究社区和工业界提供易于复制、优化和构建新项目的强化学习算法实现
- 官方文档链接:Stable-Baselines3 Docs - Reliable Reinforcement Learning Implementations
Stable-Baselines的一些特点:
Q:RL Baselines3 Zoo、SB3 Contrib和SBX (SB3 + Jax)与Stable Baselines3的关系是什么?
A:
- RL Baselines3 Zoo:RL Baselines3 Zoo是一个基于Stable Baselines3的训练框架,提供了训练、评估、调优超参数、绘图及视频录制的脚本。它的目标是提供一个简单的接口来训练和使用RL代理,同时为每个环境和算法提供调优的超参数
- SB3 Contrib:SB3 Contrib是一个包含社区贡献的强化学习算法的仓库,提供了一些实验性的算法和功能。这使得主库SB3能够保持稳定和紧凑,同时通过SB3 Contrib提供最新的算法
- SBX (SB3 + Jax):Stable Baselines Jax (SBX)是Stable Baselines3在Jax上的概念验证版本,提供了一些最新的强化学习算法, 它与SB3相比提供了较少的功能,但在某些情况下可以提供更高的性能,速度可能快达20倍。 SBX遵循SB3的API,因此与RL Zoo兼容
这三个项目都是Stable Baselines3生态系统的一部分,它们共同提供了一个全面的工具集,用于强化学习的研究和开发。SB3提供了核心的强化学习算法实现,而RL Baselines3 Zoo提供了一个训练和评估这些算法的框架。SB3 Contrib则作为实验性功能的扩展库,SBX则探索了使用Jax来加速这些算法的可能性
安装
- Stable-Baselines3 requires python 3.9+ and PyTorch >= 2.3
- Windows的要求:Python 3.8或以上
- 安装命令:
#该命令将会安装 Stable Baselines3以及一些依赖项 如Tensorboard, OpenCV or ale-py
pip install stable-baselines3[extra]
#该命令仅安装 Stable Baselines3 的核心包
pip install stable-baselines3
Example
- 官方示例代码:
import gymnasium as gym
from stable_baselines3 import A2C
env = gym.make("CartPole-v1", render_mode="rgb_array")
model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
action, _state = model.predict(obs, deterministic=True)
obs, reward, done, info = vec_env.step(action)
vec_env.render("human")
# VecEnv resets automatically
# if done:
# obs = vec_env.reset()
与直接使用gymnasium环境的不同之处
gymnasium.make
之后,还需要创建vec_envvec_env = model.get_env()
- 环境的reset使用vec_env.reset()
- VecEnv仅需在训练开始时reset,训练中无需手动reset,具体请看上述代码中最后的注释部分
Reinforcement Learning Tips and Tricks
- 强化学习与其他机器学习方法不同之处:训练的数据由智能体本身收集Reinforcement Learning differs from other machine learning methods in several ways. The data used to train the agent is collected through interactions with the environment by the agent itself(compared to supervised learning where you have a fixed dataset for instance).
- 这种依赖会导致恶性循环:如果代理收集到质量较差的数据(例如,没有奖励的轨迹),那么它就不会改进并继续积累错误的轨迹。This dependence can lead to vicious circle: if the agent collects poor quality data (e.g., trajectories with no rewards), then it will not improve and continue to amass bad trajectories.
VecEnv相关
- stable-baselines使用矢量化环境(VecEnv)
- VecEnv允许并行地在一个环境中的多个实例上运行,这样可以显著提高数据收集和训练的效率
- VecEnv支持批量操作(允许模型一次从多个环境实例中学习),可以一次性对所有环境实例执行相同的动作,然后同时获取所有实例的观测、奖励和完成状态
- 在VecEnv中,当一个环境实例完成(即done为True)时,它会自动重置