Getting Started with the stable-baselines3 Reinforcement Learning Framework via panda-gym
- 1. Getting started workflow
- 1.1 Documentation
- 1.2 Create a conda Python 3.9 environment
- 1.3 Install sb3 and panda-gym
- 1.4 Training
- 1.5 Viewing the results
- 2. Analysis of the panda-gym structure
1. Getting started workflow
1.1 Documentation
stable-baselines3 docs: stable-baselines3
baselines3-zoo docs: baselines3-zoo
panda-gym docs: panda-gym
robopal docs: robopal
1.2 Create a conda Python 3.9 environment
conda create -n panda-gym python=3.9
conda activate panda-gym
1.3 Install sb3 and panda-gym
pip install stable-baselines3
pip install panda-gym
Or, to install panda-gym from source:
git clone https://github.com/qgallouedec/panda-gym.git
pip install -e panda-gym
- Verify sb3:
import gymnasium as gym
from stable_baselines3 import A2C

env = gym.make("CartPole-v1", render_mode="rgb_array")
model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)

vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
    vec_env.render("human")
    # VecEnv resets automatically
    # if done:
    #     obs = vec_env.reset()
- Verify panda-gym:
import gymnasium as gym
import panda_gym

env = gym.make('PandaReach-v3', render_mode="human")
observation, info = env.reset()
for _ in range(1000):
    action = env.action_space.sample()  # random action
    observation, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        observation, info = env.reset()
env.close()
1.4 Training
import gymnasium as gym
import panda_gym
from stable_baselines3 import DDPG, HerReplayBuffer, TD3, SAC

env = gym.make("PandaReach-v3")
log_dir = './panda_reach_v3_tensorboard/'

# DDPG with hindsight experience replay (HER)
model = DDPG(policy="MultiInputPolicy", env=env, replay_buffer_class=HerReplayBuffer,
             verbose=1, tensorboard_log=log_dir)
model.learn(30_000)
model.save("ddpg_panda_reach_v3")

# TD3 with HER
model = TD3(policy="MultiInputPolicy", env=env, replay_buffer_class=HerReplayBuffer,
            verbose=1, tensorboard_log=log_dir)
model.learn(30_000)
model.save("td3_panda_reach_v3")

# SAC with HER
model = SAC(policy="MultiInputPolicy", env=env, replay_buffer_class=HerReplayBuffer,
            verbose=1, tensorboard_log=log_dir)
model.learn(30_000)
model.save("sac_panda_reach_v3")
1.5 Viewing the results
- Install TensorBoard to inspect the training curves
pip install tensorboard
tensorboard --logdir panda_reach_v3_tensorboard
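Since the three algorithms above all write into the same log directory, each run can be given a readable name through the tb_log_name argument of learn(), e.g.:
# Name the run so it is easy to tell apart in the TensorBoard sidebar
model.learn(30_000, tb_log_name="ddpg_her")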
- Watch the trained model run
import gymnasium as gym
import panda_gym
from stable_baselines3 import DDPG, HerReplayBuffer, TD3, SAC
import time

env = gym.make("PandaReach-v3", render_mode="human")
model = DDPG.load("ddpg_panda_reach_v3", env=env)
vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, done, info = vec_env.step(action)
    vec_env.render()
    time.sleep(0.5)
    if done:
        print('Done')
        obs = vec_env.reset()
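For a quantitative check rather than watching the viewer, sb3's built-in evaluate_policy helper works here as well; a minimal sketch:
import gymnasium as gym
import panda_gym
from stable_baselines3 import DDPG
from stable_baselines3.common.evaluation import evaluate_policy

env = gym.make("PandaReach-v3")
model = DDPG.load("ddpg_panda_reach_v3", env=env)
# Average return of the deterministic policy over 10 episodes
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
print(f"mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")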
- Convert a rollout into an animation
import gymnasium as gym
import panda_gym

env = gym.make("PandaPickAndPlace-v3", render_mode="rgb_array")
observation, info = env.reset()
images = [env.render()]
for _ in range(200):
    action = env.action_space.sample()
    observation, reward, terminated, truncated, info = env.step(action)
    images.append(env.render())
    if terminated or truncated:
        observation, info = env.reset()
        images.append(env.render())
env.close()

%pip install numpngw
from numpngw import write_apng

write_apng("anim.png", images, delay=40)  # real-time rendering = 40 ms between frames

from IPython.display import Image
Image(filename="anim.png")
2. Analysis of the panda-gym structure
panda-gym is a set of robot-learning environments for the Franka Emika Panda arm, built on the PyBullet physics engine and the Gymnasium API. It can be combined with sb3 for reinforcement learning training. An overview of the project's main directory structure:
- docs: the official documentation, guiding users from getting started to advanced usage.
- examples: example code showing how to create and use the different tasks.
- panda_gym: the core package, defining all environment-related classes and functions, with subdirectories for the environment implementations (envs), the specific tasks (tasks), and so on.
- test: unit tests that keep the code robust.
- .gitignore: files and directories excluded from version control.
- LICENSE: the project license (MIT License).
- Makefile, readthedocs.yml: build and documentation-generation configuration.
- README.md: project overview, installation guide, and quick-start steps.
- setup.py: the Python package installation script.
Base classes:
- PyBullet physics-engine wrapper: helper functions around the PyBullet API.
- Robot base class: loads the robot (URDF) model and provides basic control.
- Task base class: defines the goal, observation, reward, and success check (a toy sketch follows below).
- RobotTaskEnv (inherits gym.Env): the environment base class that combines a robot and a task.
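To make the task interface concrete, here is a toy sketch following the Task base class from panda_gym.envs.core (v3); the method bodies are illustrative and the signatures approximate, not taken from the project source:
import numpy as np

from panda_gym.envs.core import Task


class ReachFixedPoint(Task):
    """Toy task: bring the end effector to a fixed point in space."""

    def __init__(self, sim, get_ee_position):
        super().__init__(sim)
        self.get_ee_position = get_ee_position  # callable provided by the robot
        self.goal = np.array([0.1, 0.1, 0.1])   # fixed target position
        self.distance_threshold = 0.05

    def reset(self) -> None:
        pass  # nothing to randomize for a fixed goal

    def get_obs(self) -> np.ndarray:
        return np.array([])  # the robot observation already suffices

    def get_achieved_goal(self) -> np.ndarray:
        return np.array(self.get_ee_position())

    def is_success(self, achieved_goal, desired_goal) -> np.ndarray:
        d = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
        return np.array(d < self.distance_threshold, dtype=bool)

    def compute_reward(self, achieved_goal, desired_goal, info) -> np.ndarray:
        # Sparse reward: 0 when within the threshold, -1 otherwise
        d = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
        return -np.array(d >= self.distance_threshold, dtype=np.float32)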
- panda_gym.envs.panda_tasks
Defines the concrete environment classes. Each inherits from RobotTaskEnv, which in turn inherits from gym.Env and implements __init__, step, reset, render, etc., so that the classes can be registered as custom Gym environments.
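Such an environment class is roughly shaped as follows; this is a simplified sketch modeled on panda-gym v3's PandaReachEnv, with argument lists abridged:
from panda_gym.envs.core import RobotTaskEnv
from panda_gym.envs.robots.panda import Panda
from panda_gym.envs.tasks.reach import Reach
from panda_gym.pybullet import PyBullet


class PandaReachEnv(RobotTaskEnv):
    """Reach task with the Panda robot (simplified sketch)."""

    def __init__(self, render_mode="rgb_array", reward_type="sparse", control_type="ee"):
        sim = PyBullet(render_mode=render_mode)  # physics backend
        robot = Panda(sim, block_gripper=True, control_type=control_type)
        task = Reach(sim, reward_type=reward_type, get_ee_position=robot.get_ee_position)
        super().__init__(robot, task)  # RobotTaskEnv wires up step/reset/render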
- panda_gym.__init__
Registers the custom environments with Gym:
import os

from gymnasium.envs.registration import register

with open(os.path.join(os.path.dirname(__file__), "version.txt"), "r") as file_handler:
    __version__ = file_handler.read().strip()

ENV_IDS = []

for task in ["Reach", "Slide", "Push", "PickAndPlace", "Stack", "Flip"]:
    for reward_type in ["sparse", "dense"]:
        for control_type in ["ee", "joints"]:
            reward_suffix = "Dense" if reward_type == "dense" else ""
            control_suffix = "Joints" if control_type == "joints" else ""
            env_id = f"Panda{task}{control_suffix}{reward_suffix}-v3"

            register(
                id=env_id,
                entry_point=f"panda_gym.envs:Panda{task}Env",
                kwargs={"reward_type": reward_type, "control_type": control_type},
                max_episode_steps=100 if task == "Stack" else 50,
            )

            ENV_IDS.append(env_id)
The registered environments can then be created by ID:
import gymnasium as gym
import panda_gym
env = gym.make("PandaReach-v3")
To be updated…