
[Hierarchical Reinforcement Learning] A Simple Option-Critic Example on CartPole-v1

Note: the inner (intra-option) policy is trained with plain vanilla policy gradient (PG), so training is quite unstable. If you need better results, try tuning the hyperparameters yourself, or swap the inner-policy training for a more stable algorithm such as PPO (a rough sketch of a PPO-style actor loss is included after the script below).

import numpy as np
import torch
import torch.nn as nn

import gym

import torch.nn.functional as F

from torch.distributions.categorical import Categorical

class NN(nn.Module):

    def __init__(self, state_size, action_size, hidden_size, num_options):
        super().__init__()
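        # One head of each kind per option:
        #   actors[i]       : intra-option policy pi_i(a|s) (softmax over actions)
        #   terminations[i] : termination probability beta_i(s) in (0, 1)
        #   critics[i]      : option-action values Q_i(s, a)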

        self.actors = nn.ModuleList([
            nn.Sequential(
                nn.Linear(state_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, action_size),
                nn.Softmax(dim=-1)
            ) for _ in range(num_options)
        ])

        self.terminations = nn.ModuleList([
            nn.Sequential(
                nn.Linear(state_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 1),
                nn.Sigmoid()
            ) for _ in range(num_options)
        ])

        self.critics = nn.ModuleList([
            nn.Sequential(
                nn.Linear(state_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, action_size),
            ) for _ in range(num_options)
        ])

    def select_option(self, state, epsilon):
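        # Epsilon-greedy over options: with probability (1 - epsilon) pick the option whose
        # value V_i(s) = sum_a pi_i(a|s) * Q_i(s, a) is largest, otherwise a random option.
        # Returns (actor, termination head, option index) of the chosen option.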
        # print("change option")
        if np.random.rand() >= epsilon:

            max_value = - np.inf
            option_id = -1

            for i, (a, c) in enumerate(zip(self.actors, self.critics)):

                q = c(state)
                p = a(state)

                v = (q * p).sum(-1).item()

                if v >= max_value:
                    option_id = i
                    max_value = v

        else:
            option_id = np.random.randint(0, len(self.actors))

        return self.actors[option_id], self.terminations[option_id], option_id

if __name__ == '__main__':

    np.random.seed(0)
    torch.manual_seed(0)  # seed torch too so action/option sampling is reproducible

    episodes = 5000

    epsilon = 1.0
    discount = 0.9
    epsilon_decay = 0.995
    epsilon_min = 0.05

    training_epochs = 1

    env = gym.make('CartPole-v1')

    # note: this rebinds the name `nn` (imported above as torch.nn) to the model instance
    nn = NN(4, 2, 128, 6)  # state dim 4, 2 actions, 128 hidden units, 6 options
    # nn = torch.load("NN.pt")
    optimizer = torch.optim.Adam(nn.parameters(), lr=1e-2)

    max_score = 0.0

    trajectory = []

    for e in range(1, episodes + 1):

        score = 0.0

        state, _ = env.reset()

        option = nn.select_option(torch.tensor(state), epsilon)

        while True:
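            # act with the current option's policy, then sample its termination function on
            # the next state to decide whether to keep the option or pick a new one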

            policy = option[0](torch.tensor(state))
            action = Categorical(policy).sample()

            next_state, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated  # bootstrap only on true termination, not on time-limit truncation

            score += reward

            beta = option[1](torch.tensor(next_state)).item()

            # sample whether the current option terminates in the next state
            option_terminated = np.random.rand() < beta

            trajectory.append(
                (state, action, reward, next_state, done, option[2], beta, option_terminated)
            )

            if option_terminated:
                # the option has ended: pick a new option for the next step
                option = nn.select_option(torch.tensor(next_state), epsilon)

            state = next_state

            if terminated or truncated: break

        # start training
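        # For every stored transition we build an intra-option Q-learning target for the
        # critic, a policy-gradient loss for the option's actor, and a termination loss
        # weighted by how much worse the current option is than the best available option.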
        if e % training_epochs == 0:
            optimizer.zero_grad()

            q_targets = []
            option_states = []
            option_advs = []
            option_next_states = []

            for state, action, reward, next_state, done, option_id, beta, option_terminal in trajectory:

                # option values V_i(s') = sum_a pi_i(a|s') * Q_i(s', a) for every option i
                next_state_t = torch.tensor(next_state)
                option_values = [
                    (nn.critics[i](next_state_t) * nn.actors[i](next_state_t)).sum(-1).item()
                    for i in range(len(nn.critics))
                ]

                # intra-option Q-learning target: with probability (1 - beta) the current
                # option continues, with probability beta we switch to the best option
                q = reward + (1 - done) * discount * (
                    (1 - beta) * option_values[option_id] + beta * max(option_values)
                )

                # regress the critic toward the target only at the taken action
                q_target = nn.critics[option_id](torch.tensor(state)).detach().numpy()
                q_target[action.item()] = q

                q_targets.append(q_target)
                option_states.append(state)

                # advantage of continuing the current option versus switching to the best one
                # (reuses the option values computed above); this drives the termination loss
                inner_next_value = option_values[option_id]
                next_value = max(option_values)

                option_adv = inner_next_value - next_value

                option_advs.append(option_adv)
                option_next_states.append(next_state)

                if option_terminal:

                    option_states = torch.tensor(np.array(option_states))
                    q_targets = torch.tensor(np.array(q_targets))
                    option_advs = torch.tensor(np.array(option_advs)).view(-1, 1)
                    option_next_states = torch.tensor(np.array(option_next_states))

                    option_critic_loss = F.mse_loss(
                        nn.critics[option_id](option_states),
                        q_targets
                    )

                    actor_advs = q_targets - nn.critics[option_id](option_states).detach()

                    option_actor_loss = - (torch.log(nn.actors[option_id](option_states)) * actor_advs).mean()

                    option_terminal_loss = (nn.terminations[option_id](option_next_states) * option_advs).mean()

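                    # the three losses touch disjoint heads, so their gradients simply
                    # accumulate until the single optimizer.step() after the trajectory loop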
                    option_critic_loss.backward()
                    option_actor_loss.backward()
                    option_terminal_loss.backward()

                    q_targets = []
                    option_states = []
                    option_advs = []
                    option_next_states = []

            optimizer.step()

            # this batch has been used; start collecting fresh trajectories
            trajectory = []

        if epsilon > epsilon_min:

            epsilon *= epsilon_decay

        if score > max_score:

            max_score = score

            torch.save(nn, 'NN.pt')

        print("Episode: {}/{}, Epsilon: {}, Score: {}, Max score: {}".format(e, episodes, epsilon, score, max_score))

