当前位置: 首页 > article >正文

24/12/9 算法笔记<强化学习> TD3

TD3是双延迟深度确定性策略梯度,TD3是DDPG的一个优化版本。

在DQN中,我们通过Critic网络估算动作A的Q值,一个Critic的评估可能会比较高,所以我们再加一个

        这相当于我们把途中的Critic的框框,一个变为两个。

        在目标网络中,我们估算出来的Q值会用min()函数求出较少值,以这个值作为更新的目标,这个目标会更新两个网络,Critic网络1和Critic网络2。

        可以理解为两个网络相互独立,只是都用同一个目标进行更新,剩余的就和DDPG一样了,过了一段时间,把学习好的网络赋值给目标网络。

在Critic中:

        只有在计算Critic的更新目标时,我们才用target network。其中包含一个Policy network,用于计算A';两个Q network,用于计算两个Q值:Q1(A')和Q2(A’)。

        Q1(A')和Q2(A’)取最小值min(Q1,Q2)将代替DDPG的Q(a') 计算更新目标,也就是说:

target = min(Q1,Q2) * gamma + r

        target将会是Q_network_1和Q_network_2两个网络的更新目标。

问题:为什么更新目标是一样的,还需要两个网络呢?

        虽然更新目标一样,两个网络会越来越趋近与和实际q值相同。但由于网络参数的初始值不一样,会导致计算出来的值有所不同。所以我们可以有空间选择较小的值取估算q值,避免q值被高估。

在Actor中:

        actor的任务,使用梯度上升的方法,寻找使得Q值最高的action。更新Q1和Q2两个网络

  

Delayed Actor Update:

        这里的Delayed是指Actor更新的delay,也就是说相对于critic可以更新多次后,actor再进行更新。举个例子:老师打分要确定些,学生才知道要学什么。

target policy smoothing regularization        

        TD3中,价值函数的更新目标每次在action上加上一个小扰动,这就是target policy smoothing regularization,为什么要加呢?

        注意这里是给actor目标网路输出的action增加noise,而不是给真正给出实际输出的action增加noise,当实际给出的actor增加noise时,是为了让actor进行更丰富的探讨,

        这里给目标网络输出的action增加noise是为了给后面的critic目标网络增加难度。

代码示例:

我们就讲一下TD3对于DDPG的优化部分的代码

首先是Critic网络,这里定义两个Critic网络用于TD3的双网络结构

class CriticNetwork(tf.keras.Model):
    def __init__(self, state_dim, action_dim):
        super(CriticNetwork, self).__init__()
        self.fc1_s = tf.keras.layers.Dense(400, activation='relu', kernel_initializer='he_normal')
        self.fc1_a = tf.keras.layers.Dense(400, activation='relu', kernel_initializer='he_normal')
        self.fc2 = tf.keras.layers.Dense(300, activation='relu', kernel_initializer='he_normal')
        self.fc3 = tf.keras.layers.Dense(1, kernel_initializer='he_normal')

    def call(self, state, action):
        s = self.fc1_s(state)
        a = self.fc1_a(action)
        x = tf.keras.layers.concatenate([s, a])
        x = self.fc2(x)
        return self.fc3(x)

TD3的智能体类中:

我们设置两个Critic网络以及对应的目标网络

        # 两个价值网络(Critic)及对应的目标网络
        self.critic_1 = CriticNetwork(self.state_dim, self.action_dim)
        self.critic_1_target = copy.deepcopy(self.critic_1)
        self.critic_1_optimizer = tf.optimizers.Adam(learning_rate=0.002)

        self.critic_2 = CriticNetwork(self.state_dim, self.action_dim)
        self.critic_2_target = copy.deepcopy(self.critic_2)
        self.critic_2_optimizer = tf.optimizers.Adam(learning_rate=0.002)
update 方法中的优化关键步骤

        给目标动作添加噪声,这里的噪声是从正态分布中采样并进行裁剪后得到的。

        这与 DDPG 不同,DDPG 中直接使用目标 Actor 网络生成的动作,而 TD3 通过添加噪声并裁剪来探索更多动作可能性以及使 Q 值估计更平滑。

        计算两个Critic网络的目标Q值,取两者最小值作为最终的目标 Q 值,这种取最小值的方式能在一定程度上避免高估 Q 值的问题。

            # 计算两个Critic网络的目标Q值
            target_actions = self.actor_target(next_state_batch)
            noise = tf.clip_by_value(tf.random.normal(tf.shape(target_actions), stddev=self.policy_noise),
                                     -self.noise_clip, self.noise_clip)
            noisy_target_actions = tf.clip_by_value(target_actions + noise, -self.action_bound, self.action_bound)

            q1_target = self.critic_1_target(next_state_batch, noisy_target_actions)
            q2_target = self.critic_2_target(next_state_batch, noisy_target_actions)
            q_target = tf.minimum(q1_target, q2_target)
            target_q = reward_batch + (1 - done_batch) * self.gamma * q_target

计算两个Critic网络殴打预测Q值以及损失

            # 计算两个Critic网络的预测Q值及损失
            current_q1 = self.critic_1(state_batch, action_batch)
            current_q2 = self.critic_2(state_batch, action_batch)
            critic_1_loss = tf.reduce_mean(tf.square(target_q - current_q1))
            critic_2_loss = tf.reduce_mean(tf.square(target_q - current_q2))

        延迟更新策略网络Actor,给 Critic 网络足够时间去更准确地评估动作价值,从而让 Actor 网络基于更可靠的反馈进行策略调整,这区别于 DDPG 中每次迭代都会更新 Actor 网络的方式。

            # 延迟更新策略网络(Actor)
            #当self.policy_delay_count是self.policy_delay的倍数时,才更新Actor网络。
            if self.policy_delay_count % self.policy_delay == 0:
                actions = self.actor(state_batch)
                q1 = self.critic_1(state_batch, actions)
                actor_loss = -tf.reduce_mean(q1)

更新两个Critic网络

        critic_1_gradients = tape.gradient(critic_1_loss, self.critic_1.trainable_variables)
        self.critic_1_optimizer.apply_gradients(zip(critic_1_gradients, self.critic_1.trainable_variables))

        critic_2_gradients = tape.gradient(critic_2_loss, self.critic_2.trainable_variables)
        self.critic_2_optimizer.apply_gradients(zip(critic_2_gradients, self.critic_2.trainable_variables))

添加了网络噪声

self.policy_noise = 0.2  # 策略噪声系数,用于给目标动作添加噪声
self.noise_clip = 0.5  # 噪声裁剪范围
self.policy_delay = 2  # 策略更新延迟步数

总的来说,TD3 通过双 Critic 网络、目标动作添加噪声、策略更新延迟以及对应的一系列配套机制,在 DDPG 的基础上对连续动作空间的强化学习任务进行了优化,提升了训练的稳定性和最终策略的性能表现。


http://www.kler.cn/a/429934.html

相关文章:

  • CSS学习记录07
  • 专业135+总分400+华中科技大学824信号与系统考研经验华科电子信息与通信工程,真题,大纲,参考书。
  • Claude:高效智能的AI语言模型
  • 【日常记录-Java】查看Maven本地仓库的位置
  • MySQL | 尚硅谷 | 第12章_MySQL数据类型精讲
  • python爬虫--某房源网站验证码破解
  • React之react-redux的使用
  • 前后端无缝沟通:掌握API接口开发与调用的关键
  • 【JavaEE】Spring Boot 项目创建
  • C#和Java异同点
  • 数字化那点事:一文读懂云计算
  • 代码随想录算法训练营day37|动态规划part5
  • Netty 心跳机制与连接管理
  • flink-connector-mysql-cdc:02 mysql-cdc高级扩展
  • 无监督目标检测最新CVPR解读
  • 【网络安全资料文档】网络安全空间态势感知系统建设方案,网络安全数据采集建设方案(word原件)
  • scala的正则表达式的特殊规则
  • 深入探索Redis:数据结构解析与Spring Boot实战应用
  • 介绍8款开源网络安全产品
  • python数据分析之爬虫基础:requests详解