当前位置: 首页 > article >正文

Python-代码阅读-epsilon-greedy策略函数

1.代码 

def epsilon_greedy_policy(qnet, num_actions):
    def policy_fn(sess, observation, epsilon):
        # epsilon-greedy策略函数
        # 输入参数:
        #     qnet: Q网络模型,用于预测Q值
        #     num_actions: 动作空间的数量
        #     sess: TensorFlow会话,用于执行模型预测
        #     observation: 当前观测值
        #     epsilon: ε值,表示探索概率
        
        if (np.random.rand() < epsilon):  
            # 探索:等概率地选择所有动作
            A = np.ones(num_actions, dtype=float) / float(num_actions)
        else:
            # 利用:选择具有最大Q值的动作
            q_values = qnet.predict(sess, np.expand_dims(observation, 0))[0]
            max_Q_action = np.argmax(q_values)
            A = np.zeros(num_actions, dtype=float)
            A[max_Q_action] = 1.0 
        return A
    return policy_fn

2.代码阅读

该函数实现了ε-greedy策略,根据当前的Q网络模型(qnet)、动作空间的数量(num_actions)、当前观测值(observation)和探索概率ε(epsilon)选择动作。

当随机生成的随机数小于ε时,选择等概率地选择所有动作(探索),否则根据Q网络模型预测的Q值选择具有最大Q值的动作(利用)。返回一个概率分布,表示在当前状态下选择各个动作的概率。

2.1 np.ones(num_actions, dtype=float) / float(num_actions)

A = np.ones(num_actions, dtype=float) / float(num_actions)

A = np.ones(num_actions, dtype=float) / float(num_actions)这行代码的作用是创建一个包含num_actions个元素的一维数组A,每个元素的初始值都为1.0,并且将数组中的所有元素除以num_actions,从而得到一个等概率的概率分布。

具体而言,np.ones(num_actions, dtype=float)创建了一个由num_actions个元素组成的一维数组,每个元素的值都为1.0,/ float(num_actions)将数组中的每个元素除以num_actions,从而得到一个等概率的概率分布。最终,将这个概率分布赋值给数组A,表示在探索阶段,每个动作被选择的概率相等。

2.2 qnet.predict()

q_values = qnet.predict(sess, np.expand_dims(observation, 0))[0]

q_values = qnet.predict(sess, np.expand_dims(observation, 0))[0] 这行代码的作用是使用qnet模型通过输入observation进行预测,并获取预测结果中的Q值(动作值函数)。

具体而言,代码中使用np.expand_dims(observation, 0)observation转换为一个形状为(1, observation_shape)的数组,其中observation_shapeobservation的形状。这样做是为了将observation作为一个样本输入到qnet模型中进行预测。

接着,qnet.predict(sess, np.expand_dims(observation, 0))调用qnet模型的predict方法,传入sess作为会话对象和转换后的observation作为输入,得到一个包含Q值的数组。

最后,通过[0]取得数组中的第一个元素,即Q值数组,赋值给q_values,表示预测得到的Q值。这样,q_values就包含了模型对当前observation的每个动作的Q值估计。

np.expand_dims(observation, 0)

np.expand_dims(observation, 0) 这行代码的作用是将observation数组在第0维(最前面)添加一个维度。

具体而言,np.expand_dims(observation, 0)会返回一个新的数组,其中observation数组会在第0维添加一个维度。这个新的数组将具有形状(1, observation_shape),其中observation_shapeobservation数组的形状。这样做是为了将observation作为一个单独的样本输入到模型中进行预测。

例如,如果observation原本的形状是(observation_shape,),则经过np.expand_dims(observation, 0)处理后,新的数组形状将变为(1, observation_shape),其中第0维有一个大小为1的维度。这样的处理在某些情况下可以确保输入数据的维度与模型期望的输入维度一致。

2.3 max_Q_action = np.argmax(q_values)

max_Q_action = np.argmax(q_values)

max_Q_action = np.argmax(q_values) 这行代码的作用是找到Q值数组 q_values 中的最大值,并返回其对应的索引,即表示最优动作的索引

具体而言,np.argmax(q_values) 调用 np.argmax 函数,传入 q_values 数组作为参数。np.argmax 函数会返回 q_values 数组中的最大值所在的索引。这个索引表示在当前状态下,模型认为具有最高Q值(即最优动作)的动作。

将返回的最优动作索引赋值给 max_Q_action,以便后续在构建 epsilon-greedy 策略时使用。

2.4 np.zeros(num_actions, dtype=float)

A = np.zeros(num_actions, dtype=float)

A = np.zeros(num_actions, dtype=float) 这行代码的作用是创建一个形状为 (num_actions,) 的全零数组,并指定数据类型为 float

具体而言,np.zeros(num_actions, dtype=float) 调用 np.zeros 函数,传入 num_actions 参数作为数组的长度,dtype=float 参数指定数组的数据类型为 float。函数将创建一个长度为 num_actions 的全零数组,并将其数据类型设定为 float

这个数组 A 用于存储 epsilon-greedy 策略中各个动作的概率。在策略中,对于具有最高 Q 值的动作,其概率会设置为1,表示以确定性选择最优动作;而对于其他动作,其概率会设置为0,表示不选取这些动作。

2.5 A[max_Q_action] = 1.0

A[max_Q_action] = 1.0

A[max_Q_action] = 1.0 这行代码的作用是将 A 数组中索引为 max_Q_action 的位置的元素值设置为1.0。

具体而言,A[max_Q_action] 是访问 A 数组中索引为 max_Q_action 的位置的元素值。将其赋值为1.0,表示在 epsilon-greedy 策略中,最优动作的概率被设置为1.0,即以确定性选择最优动作。

这样,A 数组中只有最优动作的位置上的元素值为1.0,其余位置上的元素值都为0,从而实现了在策略中以确定性选择最优动作的效果。


http://www.kler.cn/a/9645.html

相关文章:

  • 微信小程序=》基础=》常见问题=》性能总结
  • [CKS] 关闭API凭据自动挂载
  • 机器学习基础02_特征工程
  • DApp开发:定制化解决方案与源码部署的一站式指南
  • 笔记 | image may have poor performance,or fail,if run via emulation
  • Java基础-组件及事件处理(下)
  • Spark大数据处理讲课笔记3.1 掌握RDD的创建
  • Leetcode.1019 链表中的下一个更大节点
  • HTTP协议详解(二)
  • 第五十五天打卡
  • Sentinel滑动时间窗限流算法原理及源码解析(下)
  • PACS系统中的三维重建技术:原理、实现与应用
  • 使用JavaScript编写第一个测试案例
  • MyBatisPlus标准数据层开发
  • 02-神经网络基础
  • 15个awk的经典实战案例
  • 【Go自学】Go语言中命名返回值函数对defer影响
  • 体育活动---英文单词
  • nacos和eureka的区别
  • 网络书店前端代码
  • 1.docker-安装及使用
  • item_history_price-获取商品历史价格信息 API接入参数及说明
  • 2023年MathorCup数模B题赛题
  • 如何自学JAVA
  • SQL Server的事务日志
  • CentOS7 内网安装mosquitto