当前位置: 首页 > article >正文

人工智能-A* 算法与机器学习算法结合

以下将为你展示如何将 A* 算法与机器学习算法(这里以简单的神经网络为例)结合实现路径规划。我们会先使用 A* 算法生成一些路径规划数据,然后用这些数据训练一个简单的神经网络,让神经网络学习如何预测路径。最后,将训练好的神经网络应用到路径规划任务中,实现 A* 算法与机器学习算法的结合。

代码实现

import numpy as np
import heapq
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# 地图表示
map_grid = np.array([
    [0, 0, 0, 0, 0],
    [0, 1, 1, 0, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 1, 1, 0],
    [0, 0, 0, 0, 0]
])

# A* 算法实现
class Node:
    def __init__(self, x, y, g=float('inf'), h=float('inf'), parent=None):
        self.x = x
        self.y = y
        self.g = g
        self.h = h
        self.f = g + h
        self.parent = parent

    def __lt__(self, other):
        return self.f < other.f

def heuristic(current, goal):
    return abs(current[0] - goal[0]) + abs(current[1] - goal[1])

def astar(grid, start, goal):
    rows, cols = grid.shape
    open_list = []
    closed_set = set()

    start_node = Node(start[0], start[1], g=0, h=heuristic(start, goal))
    heapq.heappush(open_list, start_node)

    while open_list:
        current_node = heapq.heappop(open_list)

        if (current_node.x, current_node.y) == goal:
            path = []
            while current_node:
                path.append((current_node.x, current_node.y))
                current_node = current_node.parent
            return path[::-1]

        closed_set.add((current_node.x, current_node.y))

        neighbors = [(0, 1), (0, -1), (1, 0), (-1, 0)]
        for dx, dy in neighbors:
            new_x, new_y = current_node.x + dx, current_node.y + dy

            if 0 <= new_x < rows and 0 <= new_y < cols and grid[new_x][new_y] == 0 and (new_x, new_y) not in closed_set:
                new_g = current_node.g + 1
                new_h = heuristic((new_x, new_y), goal)
                new_node = Node(new_x, new_y, g=new_g, h=new_h, parent=current_node)

                found = False
                for i, node in enumerate(open_list):
                    if node.x == new_x and node.y == new_y:
                        if new_g < node.g:
                            open_list[i] = new_node
                            heapq.heapify(open_list)
                        found = True
                        break

                if not found:
                    heapq.heappush(open_list, new_node)

    return None

# 生成训练数据
def generate_training_data(grid, num_samples):
    rows, cols = grid.shape
    inputs = []
    outputs = []
    for _ in range(num_samples):
        start = (np.random.randint(0, rows), np.random.randint(0, cols))
        goal = (np.random.randint(0, rows), np.random.randint(0, cols))
        path = astar(grid, start, goal)
        if path:
            input_data = np.zeros((rows, cols))
            input_data[start] = 1
            input_data[goal] = 2
            output_data = np.zeros((rows, cols))
            for point in path:
                output_data[point] = 1
            inputs.append(input_data.flatten())
            outputs.append(output_data.flatten())
    return np.array(inputs), np.array(outputs)

# 自定义数据集类
class PathDataset(Dataset):
    def __init__(self, inputs, outputs):
        self.inputs = torch.tensor(inputs, dtype=torch.float32)
        self.outputs = torch.tensor(outputs, dtype=torch.float32)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.outputs[idx]

# 定义简单的神经网络模型
class PathNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(PathNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

# 训练神经网络
def train_model(model, dataloader, criterion, optimizer, epochs):
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch {epoch + 1}, Loss: {running_loss / len(dataloader)}')

# 主程序
if __name__ == "__main__":
    # 生成训练数据
    num_samples = 1000
    inputs, outputs = generate_training_data(map_grid, num_samples)

    # 创建数据集和数据加载器
    dataset = PathDataset(inputs, outputs)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # 初始化神经网络
    input_size = map_grid.size
    hidden_size = 128
    output_size = map_grid.size
    model = PathNet(input_size, hidden_size, output_size)

    # 定义损失函数和优化器
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # 训练模型
    epochs = 10
    train_model(model, dataloader, criterion, optimizer, epochs)

    # 使用训练好的模型进行路径规划
    start = (0, 0)
    goal = (4, 4)
    input_data = np.zeros((map_grid.shape))
    input_data[start] = 1
    input_data[goal] = 2
    input_tensor = torch.tensor(input_data.flatten(), dtype=torch.float32).unsqueeze(0)
    output = model(input_tensor)
    output_path = output.detach().numpy().reshape(map_grid.shape)
    path_points = np.argwhere(output_path > 0.5)
    print("神经网络预测的路径点:", path_points)

代码解释

1. A* 算法部分
  • Node 类:用于表示地图中的节点,包含节点的坐标、g 值、h 值、f 值和父节点。
  • heuristic 函数:使用曼哈顿距离作为启发式函数,估计从当前节点到目标节点的代价。
  • astar 函数:实现 A* 算法的核心逻辑,通过维护开放列表和关闭列表,寻找从起点到终点的最短路径。
2. 数据生成部分
  • generate_training_data 函数:随机生成起点和终点,使用 A* 算法计算路径,将起点、终点信息作为输入,路径信息作为输出,生成训练数据。
3. 神经网络部分
  • PathDataset 类:自定义数据集类,用于封装训练数据。
  • PathNet 类:定义一个简单的全连接神经网络,包含一个隐藏层。
  • train_model 函数:使用均方误差损失函数和 Adam 优化器训练神经网络。
4. 主程序部分
  • 生成训练数据,创建数据集和数据加载器。
  • 初始化神经网络,定义损失函数和优化器。
  • 训练神经网络。
  • 使用训练好的神经网络进行路径规划,将起点和终点信息输入网络,输出预测的路径。

注意事项

  • 这里的神经网络是一个简单的示例,实际应用中可能需要更复杂的网络结构,如卷积神经网络(CNN),以更好地处理地图数据。
  • 训练数据的质量和数量对神经网络的性能有很大影响,可以尝试增加样本数量或使用更复杂的地图来提高模型的泛化能力。
  • 神经网络的预测结果可能不是最优路径,A* 算法可以保证找到最优路径,但神经网络可以在一定程度上提高路径规划的速度。

http://www.kler.cn/a/537780.html

相关文章:

  • 【LeetCode 刷题】贪心算法(4)-区间问题
  • windows蓝牙驱动开发-蓝牙无线电重置和恢复
  • 【场景题】架构优化 - 解耦Redis缓存与业务逻辑
  • 消费电子产品中的噪声对TPS54202的影响
  • 第 26 场 蓝桥入门赛
  • fs 文件系统模块
  • HTMLCSSJS
  • LeetCodeHot 100 第一天
  • ubuntu conda运行kivy时报“No matching FB config found”
  • java文件上传粗糙版
  • 《PYTHON语言程序设计》(2018版)1.20修改这道题,利用类的方式(二) 接近成功....(上)
  • 云原生后端|实践?
  • 安装指定版本的pnpm
  • vue知识补充
  • 多光谱技术在华为手机上的应用发展历史
  • Android 问题01_AGP_Kotlin_Compiler_Mapping
  • 地基JVM中的强引用、软引用、弱引用、虚引用的区别
  • 【高级架构师】多线程和高并发编程(一):线程的基础概念
  • Beta分布
  • 深入解析:React 事件处理的秘密与高效实践
  • STM32的HAL库开发---高级定时器
  • 【填坑】新能源汽车三电设计之常用半导体器件系统性介绍
  • 实在RPA案例|视源股份:驱动20+核心场景数字化升级,组织效能提升超80%
  • maven-依托管理
  • 使用springAI实现图片相识度搜索
  • 标准模版——添加定时器功能模块