
Implementing Real-Time Path-Planning Parameter Updates in Python with TD3 (Twin Delayed Deep Deterministic Policy Gradient)

Below is example Python code that uses the TD3 (Twin Delayed Deep Deterministic Policy Gradient) algorithm to update three parameters of a path-planning algorithm (sigma0, rho0, theta) in real time. The parameters are optimized with respect to the obstacle environment.

Implementation Approach

  1. Environment definition: define an environment containing obstacles to simulate the path-planning problem.
  2. TD3 algorithm: use TD3 to learn how to optimize the three parameters of the path-planning algorithm (its two core update rules are summarized right after this list).
  3. Training loop: train in the environment, continuously updating the policy (actor) and value (critic) networks.
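For reference, the two key TD3 mechanisms the training code below implements are (a) the clipped double-Q target with target policy smoothing, y = r + γ·(1 − d)·min(Q′₁(s′, a′), Q′₂(s′, a′)), where a′ = clip(π′(s′) + ε, −a_max, a_max) and ε is clipped Gaussian noise, and (b) delayed actor updates together with soft (Polyak) target updates, θ′ ← τ·θ + (1 − τ)·θ′.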

Code Example

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

# Define the TD3 networks
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc3 = nn.Linear(300, action_dim)
        self.max_action = max_action

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        x = self.max_action * torch.tanh(self.fc3(x))
        return x

class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        # Q1 architecture
        self.fc1 = nn.Linear(state_dim + action_dim, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc3 = nn.Linear(300, 1)
        # Q2 architecture
        self.fc4 = nn.Linear(state_dim + action_dim, 400)
        self.fc5 = nn.Linear(400, 300)
        self.fc6 = nn.Linear(300, 1)

    def forward(self, state, action):
        sa = torch.cat([state, action], 1)
        # Q1
        q1 = torch.relu(self.fc1(sa))
        q1 = torch.relu(self.fc2(q1))
        q1 = self.fc3(q1)
        # Q2
        q2 = torch.relu(self.fc4(sa))
        q2 = torch.relu(self.fc5(q2))
        q2 = self.fc6(q2)
        return q1, q2

    def Q1(self, state, action):
        sa = torch.cat([state, action], 1)
        q1 = torch.relu(self.fc1(sa))
        q1 = torch.relu(self.fc2(q1))
        q1 = self.fc3(q1)
        return q1

# TD3 algorithm class
class TD3:
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action)
        self.actor_target = Actor(state_dim, action_dim, max_action)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=3e-4)

        self.critic = Critic(state_dim, action_dim)
        self.critic_target = Critic(state_dim, action_dim)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=3e-4)

        self.max_action = max_action
        self.gamma = 0.99
        self.tau = 0.005
        self.policy_noise = 0.2
        self.noise_clip = 0.5
        self.policy_freq = 2

        self.total_it = 0

    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1))
        with torch.no_grad():
            return self.actor(state).cpu().data.numpy().flatten()

    def train(self, replay_buffer, batch_size=100):
        self.total_it += 1
        # Sample a batch of transitions from the replay buffer
        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

        with torch.no_grad():
            # Select the target action and add clipped noise (target policy smoothing)
            noise = (
                torch.randn_like(action) * self.policy_noise
            ).clamp(-self.noise_clip, self.noise_clip)

            next_action = (
                self.actor_target(next_state) + noise
            ).clamp(-self.max_action, self.max_action)

            # Compute the target Q value using the minimum of the two target critics
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + not_done * self.gamma * target_Q

        # Get the current Q estimates
        current_Q1, current_Q2 = self.critic(state, action)

        # Compute the critic loss
        critic_loss = nn.MSELoss()(current_Q1, target_Q) + nn.MSELoss()(current_Q2, target_Q)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy updates
        if self.total_it % self.policy_freq == 0:
            # Compute the actor loss
            actor_loss = -self.critic.Q1(state, self.actor(state)).mean()

            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Soft-update the target networks
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

# Replay buffer class
class ReplayBuffer:
    def __init__(self, max_size):
        self.buffer = deque(maxlen=max_size)

    def add(self, state, action, next_state, reward, done):
        self.buffer.append((state, action, next_state, reward, 1 - done))

    def sample(self, batch_size):
        state, action, next_state, reward, not_done = zip(*random.sample(self.buffer, batch_size))
        return (
            torch.FloatTensor(np.array(state)),
            torch.FloatTensor(np.array(action)),
            torch.FloatTensor(np.array(next_state)),
            torch.FloatTensor(np.array(reward)).unsqueeze(1),
            torch.FloatTensor(np.array(not_done)).unsqueeze(1),
        )

    def __len__(self):
        return len(self.buffer)

# Simulated path-planning environment
class PathPlanningEnv:
    def __init__(self):
        # Simple simulated obstacle environment, represented here as a 2D 0/1 array
        self.obstacles = np.random.randint(0, 2, (10, 10))
        self.state_dim = 10 * 10  # dimensionality of the environment state
        self.action_dim = 3  # the three parameters sigma0, rho0, theta
        self.max_action = 1.0

    def reset(self):
        # Reset the environment
        return self.obstacles.flatten()

    def step(self, action):
        sigma0, rho0, theta = action
        # Placeholder reward; replace this with one based on the actual path-planning algorithm
        reward = np.random.randn()
        done = False
        next_state = self.obstacles.flatten()
        return next_state, reward, done

# Main training loop
def main():
    env = PathPlanningEnv()
    state_dim = env.state_dim
    action_dim = env.action_dim
    max_action = env.max_action

    td3 = TD3(state_dim, action_dim, max_action)
    replay_buffer = ReplayBuffer(max_size=1000000)

    total_steps = 10000
    episode_steps = 0
    state = env.reset()

    for step in range(total_steps):
        episode_steps += 1
        # Select an action
        action = td3.select_action(state)

        # Execute the action
        next_state, reward, done = env.step(action)

        # Add the transition to the replay buffer
        replay_buffer.add(state, action, next_state, reward, done)

        # Train TD3 once enough samples have been collected
        if len(replay_buffer) > 100:
            td3.train(replay_buffer)

        state = next_state

        if done or episode_steps >= 100:
            state = env.reset()
            episode_steps = 0

    # Print the final optimized parameters
    final_state = env.reset()
    final_action = td3.select_action(final_state)
    sigma0, rho0, theta = final_action
    print(f"Optimized sigma0: {sigma0}, rho0: {rho0}, theta: {theta}")

if __name__ == "__main__":
    main()
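After training, the learned actor can be persisted and reloaded for later use. A minimal sketch, assuming access to the td3 object and the dimensions created in main() (the file name here is arbitrary):

# Sketch: save the trained actor (file name is arbitrary)
torch.save(td3.actor.state_dict(), "td3_actor.pth")

# Later, rebuild the network and load the weights before selecting actions
actor = Actor(state_dim, action_dim, max_action)
actor.load_state_dict(torch.load("td3_actor.pth"))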

Code Explanation

  1. Network definitions: the Actor and Critic networks generate actions and estimate action values, respectively.
  2. TD3 class: implements the core TD3 logic, including action selection, training, and soft updates of the target networks.
  3. ReplayBuffer class: stores and samples experience data.
  4. PathPlanningEnv class: simulates a path-planning environment with obstacles and provides reset and step methods.
  5. Main training loop: trains in the environment, continuously updating the policy and value networks (a note on exploration noise follows this list).
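One practical detail about the training loop above: select_action is purely deterministic, and TD3 is normally trained with Gaussian exploration noise added to the actions it collects. A minimal sketch of how the action-selection step in main() could be adapted (expl_noise is an assumed hyperparameter, not part of the original code):

# Sketch: add Gaussian exploration noise during data collection
expl_noise = 0.1  # assumed hyperparameter, not in the original code
action = td3.select_action(state)
action = (action + np.random.normal(0, expl_noise * max_action, size=action_dim)).clip(
    -max_action, max_action
)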

Notes

  • The reward calculation in this example is a simple placeholder; in a real application it should be replaced with one based on the specific path-planning algorithm (a possible sketch is given after this list).
  • The representation of the obstacle environment can be adjusted to match actual requirements.
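As a concrete, purely hypothetical illustration of the first point, the reward could evaluate a short rollout of a simple artificial-potential-field-style planner driven by sigma0, rho0, and theta, rewarding collision-free progress toward a goal. The compute_reward function below, its goal/start arguments, and the way the three parameters are used are all assumptions made for this sketch; they are not part of the original code or any standard library:

import numpy as np

def compute_reward(obstacles, action, goal=(9, 9), start=(0, 0)):
    """Hypothetical reward: roll out a few steps of a simple potential-field
    planner driven by (sigma0, rho0, theta) and score the resulting path."""
    # The actor outputs values in [-1, 1]; map them to small positive gains for this sketch
    sigma0, rho0, theta = (np.asarray(action) + 1.0) / 2.0 + 1e-3
    pos = np.array(start, dtype=float)
    goal = np.array(goal, dtype=float)
    obs_cells = np.argwhere(obstacles == 1)
    reward = 0.0
    for _ in range(20):  # fixed, short rollout horizon
        attract = sigma0 * (goal - pos)  # attraction toward the goal
        repulse = np.zeros(2)
        for cell in obs_cells:  # repulsion from nearby obstacles
            diff = pos - cell
            dist = np.linalg.norm(diff) + 1e-6
            if dist < 3.0:
                repulse += rho0 * diff / dist ** 3
        step = attract + repulse
        pos = pos + theta * step / (np.linalg.norm(step) + 1e-6)  # theta acts as a step size
        reward -= 0.1  # time penalty encourages short paths
        cell_idx = np.clip(np.round(pos).astype(int), 0, obstacles.shape[0] - 1)
        if obstacles[cell_idx[0], cell_idx[1]] == 1:
            return reward - 10.0  # collision penalty
        if np.linalg.norm(pos - goal) < 1.0:
            return reward + 10.0  # goal bonus
    return reward - 0.1 * np.linalg.norm(pos - goal)  # penalize remaining distance

Inside PathPlanningEnv.step this would replace the random reward, e.g. reward = compute_reward(self.obstacles, action); again, this is only an illustration, and the real reward should reflect whatever planner actually consumes sigma0, rho0, and theta.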
