当前位置: 首页 > article >正文

ray.rllib 入门实践-4: 构建算法

        在前面的博客 ray.rllib 入门实践-2:配置算法-CSDN博客 介绍了ray.rllib中的几种配置算法的方法,在示例代码中同步给出了构建(build)算法的方法,但是没有对构建算法的方式进行归纳介绍。

        本博客主要梳理ray.rllib中,从config生成可训练的algorithm的几种方式。

方式1 : algo = AlgorithmConfig().build()

示例 1: 

import os 
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print

## 配置算法
storage_path = "F:/codes/RLlib_study/ray_results/build_method_3"
config = PPOConfig()
config = config.rollouts(num_rollout_workers=2)
config = config.resources(num_gpus=0,num_cpus_per_worker=1,num_gpus_per_worker=0)
config = config.environment(env="CartPole-v1",env_config={})
config.output = storage_path  ## 设置过程文件的存储路径

## 构建算法
algo = config.build()

## 训练
for i in range(3):
    result = algo.train() 
    print(f"episode_{i}")

示例 2:

import os 
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print

## 配置并构建算法
algo = (
    PPOConfig()
    .rollouts(num_rollout_workers=1)
    .resources(num_gpus=0)
    .environment(env="CartPole-v1")
    .build()  ## 在这里实现算法构建
)

## 训练
for i in range(3):
    result = algo.train() 
    print(f"episode_{i}")

方式2: algo = Algorithm(config=config)

示例1:

import os 
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print

## 配置算法
config = PPOConfig()
config = config.rollouts(num_rollout_workers=2)
config = config.resources(num_gpus=0)
config = config.environment(env="CartPole-v1")

## 构建算法
algo = PPO(config=config)

## 训练
for i in range(3):
    result = algo.train() 
    print(f"episode_{i}")

示例2:

import os 
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print

## 配置算法
storage_path = "F:/codes/RLlib_study/ray_results"
os.makedirs(storage_path, exist_ok=True)
config = {
    "env":"CartPole-v1",
    "env_config":{}, ## 用于传递给env的信息
    "frame_work":"torch",
    "num_gpus":0,
    "num_workers":2,
    "num_cpus_per_worker":1,
    "num_envs_per_worker":1,
    "num_gpus_per_worker":0,
    "lr":0.001,
    "model":{
        "fcnet_hiddens":[256,256,64],
        "fcnet_activation":"tanh",
        "custom_model_config":{},
        "custom_model":None},
    "output":storage_path
}

## 构建算法
algo = PPO(config=config)

## 训练
for i in range(3):
    result = algo.train() 
    print(f"episode_{i}")

在这两个示例中, 分别向PPO算法中传入一个 PPOConfig 类对象 和 一个 DIct, 在PPO算法内部,都会把config转变成dict格式使用。用 PPOConfig 进行配置更利于开发者使用python的语法提示。

可运行代码:

import os 
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print

################### build-algo  ##########################################
## build-method-1
# config = PPOConfig()\
#         .rollouts(num_rollout_workers = 2)\
#         .resources(num_gpus=0)\
#         .environment(env="CartPole-v1")
# algo = config.build()

## build-method-2
# algo = (
#     PPOConfig()
#     .rollouts(num_rollout_workers=1)
#     .resources(num_gpus=0)
#     .environment(env="CartPole-v1")
#     .build()
# )

## build-method-3 : 推荐 ,推荐原因: 有代码提示。
# storage_path = "F:/codes/RLlib_study/ray_results/build_method_3"
# config = PPOConfig()
# config = config.rollouts(num_rollout_workers=2)
# config = config.resources(num_gpus=0,num_cpus_per_worker=1,num_gpus_per_worker=0)
# config = config.environment(env="CartPole-v1",env_config={})
# config.output = storage_path  ## 设置过程文件的存储路径
# algo = config.build()

## build-method-4 : 推荐
config = PPOConfig()
config = config.rollouts(num_rollout_workers=2)
config = config.resources(num_gpus=0)
config = config.environment(env="CartPole-v1")
algo = PPO(config=config)

## build-method-5 : 不推荐
# config = PPOConfig()
# config = config.rollouts(num_rollout_workers=2)
# config = config.resources(num_gpus=0)
# algo = PPO(config=config,env="CartPole-v1") ## 这种配置env的方式即将被抛弃

## build-method-6 : error
## 此种方式会报错: Cannot set attribute (num_rollout_workers) of an already frozen AlgorithmConfig!
## 一经执行build之后,就不能再修改配置了。
# algo = PPO(env="CartPole-v1")
# algo.config.rollouts(num_rollout_workers=2)
# algo.config.resources(num_gpus=0)

## build-method-7 :error
# config = PPOConfig()
# config["rollouts"] = {"num_rollout_workers":2}
# config["resources"] = {"num_gpus":0}
# config["environment"] = {"env":"CartPole-v1"}
# algo = config.build()

## build-method-8 : 次推荐。 优点:简洁直观。 缺点: 配置时没有代码提示。
##    method-1,2,3,4,5 在从config生成algo时, build函数内在的把 PPOConfig类对象转变成以下格式。
##    rllib中直接使用以下格式创建算法。 缺点是,在配置算法时, 没有代码提示。 
##    PPOConfig 类可以认为是对以下字典的一种升级封装,更便于使用。 
# storage_path = "F:/codes/RLlib_study/ray_results/build_method_8"
# os.makedirs(storage_path, exist_ok=True)
# config = {
#     "env":"CartPole-v1",
#     "env_config":{}, ## 用于传递给env的信息
#     "frame_work":"torch",
#     "num_gpus":0,
#     "num_workers":2,
#     "num_cpus_per_worker":1,
#     "num_envs_per_worker":1,
#     "num_gpus_per_worker":0,
#     "lr":0.001,
#     "model":{
#         "fcnet_hiddens":[256,256,64],
#         "fcnet_activation":"tanh",
#         "custom_model_config":{},
#         "custom_model":None},
#     "output":storage_path
# }
# algo = PPO(config=config)

#################  print config  #######################################
## 打印配置信息: method-1, 推荐
config_dict = config.to_dict() # algo_config是一个类对象,需要转变成dict才能打印和显示
# print(config_dict)
print(pretty_print(config_dict)) ## 以更优雅的方式打印,便于阅读。

## 打印配置信息: method-2
# algo_config = algo.get_config()
# config_dict = algo_config.to_dict() # algo_config是一个类对象,需要转变成dict才能打印和显示
# print(config_dict)



###################  训练  #############################################
for i in range(3):
    result = algo.train() ## algo_train是执行一个episode的训练,返回一个result_dict, 包含了许多信息,可以去解析使用
    # print(pretty_print(result))
    print(f"episode_{i}")

    ## 保存模型
    if i % 2 == 0:
        checkpoint_dir = algo.save().checkpoint.path
        print(f"Checkpoint saved in directory {checkpoint_dir}")


http://www.kler.cn/a/518278.html

相关文章:

  • 834 数据结构(自用)
  • 跟我学C++中级篇——64位的处理
  • 高效流式大语言模型(StreamingLLM)——基于“注意力汇聚点”的突破性研究
  • Docker 国内镜像源
  • 基于 WEB 开发的在线学习系统设计与开发
  • 【Uniapp-Vue3】setTabBar设置TabBar和下拉刷新API
  • debian12使用kvm安装windows系统
  • solidity基础 -- 事件
  • 如何用数据编织、数据虚拟化与SQL-on-Hadoop打造实时、可扩展兼容的数据仓库?
  • 【python】四帧差法实现运动目标检测
  • 如何做一个C#仿Halcon Calibration插件
  • 大模型学习计划
  • python判断字符串是否存在空白、字母或数字
  • 单链表算法实战:解锁数据结构核心谜题——移除链表元素
  • 计算机网络 (54)系统安全:防火墙与入侵检测
  • 论文速读|Matrix-SSL:Matrix Information Theory for Self-Supervised Learning.ICML24
  • 机器学习11-学习路径推荐
  • Solon Cloud Gateway 开发:导引
  • 99.15 金融难点通俗解释:毛利率vs营业利润率vs净利率
  • AI画笔,绘就古今艺术星河(5/10)
  • 【Docker】私有Docker仓库的搭建
  • K8S中Service详解(三)
  • 食堂订餐小程序ssm+论文源码调试讲解
  • pytorch2.5实例教程
  • poi在word中打开本地文件
  • Cloudflare通过代理服务器绕过 CORS 限制:原理、实现场景解析