当前位置: 首页 > article >正文

基于panda-gym上手stable-baselines3强化学习框架

基于panda-gym上手stable-baselines3强化学习框架

  • 一、入门流程
    • 1.1 文档
    • 1.2 创建conda python3.9环境
    • 1.3 安装sb3、panda-gym
    • 1.4 训练
    • 1.5 查看结果
  • 二、panda-gym结构分析

一、入门流程

1.1 文档

stable-baselines3文档: stable-baselines3
baselines3-zoo文档: baselines3-zoo

panda-gym文档: panda-gym
robopal文档: robopal

1.2 创建conda python3.9环境

conda create -n panda-gym python=3.9
conda activate panda-gym

1.3 安装sb3、panda-gym

pip install stable-baselines3
pip install panda-gym

从源码安装panda-gym:

git clone https://github.com/qgallouedec/panda-gym.git
pip install -e panda-gym
  • 验证sb3:
import gymnasium as gym

from stable_baselines3 import A2C

env = gym.make("CartPole-v1", render_mode="rgb_array")

model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)

vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
    vec_env.render("human")
    # VecEnv resets automatically
    # if done:
    #   obs = vec_env.reset() 
  • 验证panda-gym:
import gymnasium as gym
import panda_gym

env = gym.make('PandaReach-v3', render_mode="human")

observation, info = env.reset()

for _ in range(1000):
    action = env.action_space.sample() # random action
    observation, reward, terminated, truncated, info = env.step(action)

    if terminated or truncated:
        observation, info = env.reset()

1.4 训练

import gymnasium as gym
import panda_gym
from stable_baselines3 import DDPG, HerReplayBuffer, TD3, SAC

env = gym.make("PandaReach-v3")
log_dir = './panda_reach_v3_tensorboard/'

model = DDPG(policy="MultiInputPolicy", env=env, replay_buffer_class=HerReplayBuffer, 
             verbose=1, tensorboard_log=log_dir)
model.learn(30_000)
model.save("ddpg_panda_reach_v3")

model = TD3(policy="MultiInputPolicy", env=env, replay_buffer_class=HerReplayBuffer,
            verbose=1, tensorboard_log=log_dir)
model.learn(30_000)
model.save("td3_panda_reach_v3")

model = SAC(policy="MultiInputPolicy", env=env, replay_buffer_class=HerReplayBuffer,
            verbose=1, tensorboard_log=log_dir)
model.learn(30_000)
model.save("sac_panda_reach_v3")

1.5 查看结果

  • 安装tensorboard查看训练过程
pip install tensorboard
tensorboard --logdir panda_reach_v3_tensorboard
  • 查看模型运行结果
import gymnasium as gym
import panda_gym
from stable_baselines3 import DDPG, HerReplayBuffer, TD3, SAC
import time

env = gym.make("PandaReach-v3", render_mode="human")
model = DDPG.load("ddpg_panda_reach_v3", env=env)
vec_env = model.get_env()
obs = vec_env.reset()

for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, done, info = vec_env.step(action)
    vec_env.render()
    time.sleep(0.5)
    if done:
        print('Done')
        obs = vec_env.reset()
  • 转换成动画
import gymnasium as gym
import panda_gym

env = gym.make("PandaPickAndPlace-v3", render_mode="rgb_array")
observation, info = env.reset()

images = [env.render()]
for _ in range(200):
    action = env.action_space.sample()
    observation, reward, terminated, truncated, info = env.step(action)
    images.append(env.render())

    if terminated or truncated:
        observation, info = env.reset()
        images.append(env.render())

env.close()

%pip install numpngw
from numpngw import write_apng

write_apng("anim.png", images, delay=40)  # real-time rendering = 40 ms between frames

from IPython.display import Image

Image(filename="anim.png")

二、panda-gym结构分析

panda-gym 是一个基于PyBullet物理引擎和Gymnasium环境的机器人学习框架,专为Franka Emika Panda机器人设计的一系列环境。可以与sb3结合完成强化学习训练,下面是项目的主要目录结构概述:

  • docs: 包含项目的官方文档,指导用户如何开始和进阶。
  • examples: 提供示例代码,帮助用户快速上手如何创建和使用环境中不同的任务。
  • panda_gym: 核心代码库,定义了所有与环境相关的类和函数,如环境实现(envs)、特定任务(tasks)等子目录。
  • test: 单元测试相关文件,用于确保代码的健壮性。
  • .gitignore: Git忽略文件,指定不需要纳入版本控制的文件或目录。
  • LICENSE: 项目的授权协议,采用MIT License。
  • Makefile, readthedocs.yml: 构建和文档生成相关的配置文件。
  • README.md: 项目简介,安装指南和快速入门步骤。
  • setup.py: Python包的安装脚本。

基类:
pybullet物理引擎函数:
在这里插入图片描述
机器人(urdf)模型生成和基础控制:
在这里插入图片描述
任务基类:
在这里插入图片描述
RobotTaskEnv(继承自gym.env)包含机器人和任务的环境基类:
在这里插入图片描述


  • panda_gym.envs.panda_tasks
    定义了具体的环境类,继承RobotTaskEnv,RobotTaskEnv继承gym.env,实现了 init、step、reset、render 等方法,用于在 Gym 中注册自定义环境。

  • panda_gym.init

在 Gym 中注册自定义环境

import os

from gymnasium.envs.registration import register

with open(os.path.join(os.path.dirname(__file__), "version.txt"), "r") as file_handler:
    __version__ = file_handler.read().strip()

ENV_IDS = []

for task in ["Reach", "Slide", "Push", "PickAndPlace", "Stack", "Flip"]:
    for reward_type in ["sparse", "dense"]:
        for control_type in ["ee", "joints"]:
            reward_suffix = "Dense" if reward_type == "dense" else ""
            control_suffix = "Joints" if control_type == "joints" else ""
            env_id = f"Panda{task}{control_suffix}{reward_suffix}-v3"

            register(
                id=env_id,
                entry_point=f"panda_gym.envs:Panda{task}Env",
                kwargs={"reward_type": reward_type, "control_type": control_type},
                max_episode_steps=100 if task == "Stack" else 50,
            )

            ENV_IDS.append(env_id)

在 Gym 中注册自定义环境

import gymnasium as gym
import panda_gym

env = gym.make("PandaReach-v3")

待更新…


http://www.kler.cn/a/354713.html

相关文章:

  • 简单说说 spring 是如何处理循环依赖问题的(源码解析)
  • Unity 2D角色的跳跃与二段跳示例
  • Springboot 整合 Java DL4J 实现物流仓库货物分类
  • 论文翻译 | LARGE LANGUAGE MODELS ARE HUMAN-LEVELPROMPT ENGINEERS
  • 计算机网络自顶向下(4)---应用层HTTP协议
  • Nginx在Windows Server下的启动脚本
  • 20201017-【C、C++】跳动的爱心
  • Git推送被拒
  • exists在sql中的妙用
  • Linux笔记---vim的使用
  • OpenHarmony 入门——ArkUI 自定义组件内同步的装饰器@State小结(二)
  • vue使用gdal-async获取tif文件的缩略图
  • 【系统架构设计师】案例分析考点情况分析和解题技巧(包括2009~2024年考点详情)
  • 详解UDP-TCP网络编程
  • 【C#生态园】提升数据处理效率:C#中多款数据清洗库全面解析
  • 【wpf】07 后端验证及令牌码获取步骤
  • [旧日谈]关于Qt的刷新事件频率,以及我们在Qt的框架上做实时的绘制操作时我们该关心什么。
  • 关于FFmpeg【使用方法、常见问题、解决方案等】
  • jmeter 对 dubbo 接口测试是怎么实现的?有哪几个步骤
  • 我谈结构自相似性SSIM——实质度量的是什么?